[codeforces 44G] KDTree

题目
【题目大意】

给定若干个靶子(xl, xr, yl, yr, z),z为该靶子离射击位置的距离,所有靶子都可以看成是二维平面上平行于坐标轴的矩形。然后按顺序给定若干个子弹的射击位置(x, y),子弹射到一个靶子就会将靶子打碎,并掉落到地上。问每个子弹射到的靶子是谁。保证靶子的z值不相同。

【算法】
把子弹变成包含三个元素的点(x, y, id(就是说这是第几个子弹))。然后把靶子按z排序,从前往后扫,对于每个靶子,去找自己这个区域里面权值(id)最小的子弹,匹配上以后把这个点删掉。重复上面的操作,最后就能得到结果。
所以这里需要一个找矩形内权值最小的点的数据结构,并能支持删除操作。显然,KDTree是一个比较好的选择,查找操作是$$O(\sqrt{n})$$的,删除操作是$$O(\log n)$$的。

【时间复杂度】$$O(m \sqrt{n} )$$
【空间复杂度】$$O(n)$$
【代码】
WA了好多次,改得略丑。但是速度还是很快的。

#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
using namespace std;
template  void checkmin(T &t,T x){if (x < t) t = x;}
template  void checkmax(T &t,T x){if (x > t) t = x;}
template  void _checkmin(T &t, T x){if (t == -1) t = x; if (x < t) t = x;}
template  void _checkmax(T &t, T x){if (t == -1) t = x; if (x > t) t = x;}
typedef pair  PII;
typedef pair  PDD;
typedef long long lld;
#define foreach(it,v) for (__typeof((v).begin()) it = (v).begin();it != (v).end();it++)
#define DEBUG(a) cout << #a" = " << (a) << endl;
#define DEBUGARR(a, n) for (int i = 0; i < (n); i++) { cout << #a"[" << i << "] = " << (a)[i] << endl; }
const int N = 105555;
const int INF = 1000000000;
int n;
struct point {
	int x, y, id;
	point() {}
	point(int x, int y, int id):x(x),y(y),id(id){}
};

struct query {
	int xl, xr, yl, yr, z, id;
	query(int xl, int xr, int yl, int yr, int z, int id):xl(xl), xr(xr), yl(yl), yr(yr), z(z), id(id) {}
};
vector  Q;

bool _cmp(query a, query b) { return a.z < b.z; }
bool _cmp2(point a, point b) { return a.id < b.id; }

namespace {
	int pt;
	struct Node {
		bool div;
		int size;
		int minvalue;
		point element;
		Node *lc, *rc;

		void update() {
			if (lc && rc) {
				minvalue = min(lc->minvalue, rc->minvalue);
			} else {
				minvalue = size ? element.id : INF;
			}
		}
	}pool[N * 4], *root;

	bool cmpX(point a, point b) { return a.x < b.x || (a.x == b.x && a.y < b.y) || (a.x == b.x && a.y == b.y && a.id < b.id); }
	bool cmpY(point a, point b) { return a.y < b.y || (a.y == b.y && a.x < b.x) || (a.x == b.x && a.y == b.y && a.id < b.id); }
	bool cmp(point a, point b, bool div) { return div ? cmpY(a, b) : cmpX(a, b); }

	Node *build(point *a, int l, int r, bool div) {
		int mid = (l + r) / 2;
		nth_element(a + l, a + mid, a + r + 1, div ? cmpY : cmpX);
		Node *ret = &pool[pt++];
		ret->div = div;
		ret->size = r - l + 1;
		ret->element = a[mid];
		if (l != r) {
			ret->lc = build(a, l, mid, !div);
			ret->rc = build(a, mid + 1, r, !div);
		} else {
			ret->lc = ret->rc = NULL;
		}
		ret->update();
		return ret;
	}

	void remove(Node *p, const point o) {
		p->size--;
		if (p->lc && p->rc) {
			if (cmp(p->element, o, p->div)) {
				remove(p->rc, o);
			} else {
				remove(p->lc, o);
			}
		}
		p->update();
	}

	int getMin(Node *p, int xl, int xr, int yl, int yr) {
		if (!p || !p->size) return INF;
		if (xl == -INF && xr == INF && yl == -INF && yr == INF) return p->minvalue;
		if (!p->lc && !p->rc) {
		    return xl <= p->element.x && p->element.x <= xr && yl <= p->element.y && p->element.y <= yr ? p->element.id : INF;
		}
		if (!p->div) {
			int ret = INF;
			if (xl <= p->element.x)
				checkmin(ret, getMin(p->lc, xl, xr < p->element.x ? xr : INF, yl, yr));
			if (xr >= p->element.x)
				checkmin(ret, getMin(p->rc, xl > p->element.x ? xl : -INF, xr, yl, yr));
			return ret;
		} else {
			int ret = INF;
			if (yl <= p->element.y)
				checkmin(ret, getMin(p->lc, xl, xr, yl, yr < p->element.y ? yr : INF));
			if (yr >= p->element.y)
				checkmin(ret, getMin(p->rc, xl, xr, yl > p->element.y ? yl : -INF, yr));
			return ret;
		}
	}
}

point a[N];
int ans[N];

int main() {
//    freopen("in", "r", stdin);
	scanf("%d", &n);
	for (int i = 0; i < n; i++) {
		int xl, xr, yl, yr, z;
		scanf("%d%d%d%d%d", &xl, &xr, &yl, &yr, &z);
		Q.push_back(query(xl, xr, yl, yr, z, i + 1));
	}
	sort(Q.begin(), Q.end(), _cmp);
	scanf("%d", &n);
	for (int i = 0; i < n; i++) {
		int x, y;
		scanf("%d%d", &x, &y);
		a[i] = point(x, y, i + 1);
	}
	pt = 0;
	root = build(a, 0, n - 1, 0);
	sort(a, a + n, _cmp2);
	foreach (it, Q) {
//	    printf("%d\n", it->id);
		int ret = getMin(root, it->xl, it->xr, it->yl, it->yr);
		if (ret != INF) {
			ans[ret] = it->id;
			remove(root, a[ret - 1]);
		}
	}
	for (int i = 1; i <= n; i++) {
		printf("%d\n", ans[i]);
	}
	return 0;
}

留下评论

您的电子邮箱地址不会被公开。 必填项已用 * 标注