[HDU4742 Pinball Game 3D] 分治、KD树

传送门

三维LIS。即每个点有个三维坐标,两个点能放在一前一后当且仅当$$x_i < x_j, y_i < y_j, z_i < z_j$$,求最长的序列,并该条件下的方案数。

网络赛的题,比赛的时候一想二维的,直接上了KD树就过了。
赛后听别人说,都是树套树或者CDQ分治的。于是就去cxlove大大的博客学习了一下CDQ分治。
这种分治的大致思路是每次把区间分两半,先做左边,然后把左边对右边的影响(在这题里相当于dp值的转移)计算好,然后再右边递归下去算。
这种复杂度很容易计算,T(n) = 2T(n / 2) + O(xxxn),最终复杂度肯定是xxxn lg n(想想每个元素被碰的次数就好了)。
我目前对这种分治的理解大致有这么两条:

  1. 在第i号位置的元素其实被更新了popcount(i)次(popcount表示这个值在二进制下包含的1的个数),就是二进制下看,每有一个1,就会被对应长度的块更新一次。
  2. 若i < j,则在通过f[i]影响f[j]的时候,f[i]总是已经计算好了的。

在这题里显然先对x排序,然后变成一个二维查询的问题,利用这个思想结合树状数组就可以达到$$O(n lg^2 n)$$的复杂度。

KD树的做法就更为直观,直接按x排序后,动态求左上角的那个块的值就好了。大概是写得太奔放的原因,我的KD树比这个分治的要快

今天写的时候WA了好久,原因有两个

  1. 分治的时候sort了原数组在处理后要还原回去。
  2. 我是每个区间分别离散化的,结果离散化了以后,被离散化的值没有还原回去,导致在别的区间内再离散化的的时候,值的大小相对关系就不对了…

分治的代码:

#include <cstdio>
#include <cstring>
#include <set>
#include <map>
#include <vector>
#include <iostream>
#include <algorithm>
using namespace std;
#define rep(i,n) for (int i = 0; i < (int)(n); i++)
#define foreach(it,v) for (__typeof((v).end()) it = (v).begin(); it != (v).end(); it++)
typedef pair <int, unsigned int> PII;
const int N = 100005;
const unsigned int Mod = 1 << 30;
struct Point {
	int x, y, z, nz;
	PII res;
	void read() { scanf("%d%d%d", &x, &y, &z); }
}a[N];
PII Tr[N];

bool cmpx(Point a, Point b) {
	return a.x < b.x || (a.x == b.x && a.y < b.y) || (a.x == b.x && a.y == b.y && a.z < b.z);
}

bool cmpy(Point a, Point b) {
	return a.y < b.y || (a.y == b.y && a.z < b.z);
}

void up(PII &r, PII x) {
	if (x.first > r.first)
		r = x;
	else if (x.first == r.first)
		r.second += x.second;
}

void add(int i, int n, PII c) {
	for (i++, n++; i < n; i += i & -i)
		up(Tr[i], c);
}

PII get(int i) {
	PII res(0, 0);
	for (i++; i; i -= i & -i)
		up(res, Tr[i]);
	return res;
}

void gao(int l, int r) {
	if (l + 1 >= r) return;
	int mid = (l + r) / 2;
	gao(l, mid);
	sort(a + l, a + mid, cmpy);
	sort(a + mid, a + r, cmpy);
	vector <int> v;
	for (int i = l; i < r; i++) v.push_back(a[i].z);
	sort(v.begin(), v.end());
	for (int i = l; i < r; i++) a[i].nz = lower_bound(v.begin(), v.end(), a[i].z) - v.begin();
	fill(Tr + 1, Tr + r - l + 1, PII(0, 0));
	for (int j = l, i = mid; i < r; i++) {
		while (j < mid && a[j].y <= a[i].y) {
			add(a[j].nz, r - l, a[j].res);
			j++;
		}
		PII cur = get(a[i].nz);
		cur.first++;
		up(a[i].res, cur);
	}
	sort(a + mid, a + r, cmpx);
	gao(mid, r);
}

int main() {
	int Tc;
	scanf("%d", &Tc);
	while (Tc--) {
		int n;
		scanf("%d", &n);
		rep (i, n) a[i].read();
		rep (i, n) a[i].res = PII(1, 1);
		sort(a, a + n, cmpx);
		gao(0, n);
		PII res = PII(0, 1);
		rep (i, n) up(res, a[i].res);
		cout << res.first << " " << (res.second & (Mod - 1)) << endl;
	}
}

KD树的代码:

#include <cstdio>
#include <cstring>
#include <vector>
#include <iostream>
#include <algorithm>
using namespace std;
#define rep(i,n) for (int i = 0; i < (int)(n); i++)
typedef pair <int, int> PII;
typedef unsigned int ui;
const int N = 100005;
const int INF = 0x7FFFFFFF;
const ui mask = (1 << 30) - 1;
int Tc, n, m;
struct Node {
    PII e, sub, cur;
    int o;
    Node *lc, *rc;
}*C, pool[N];

struct TPoint {
    int x, y, z;
    TPoint() {}
    TPoint(int x, int y, int z): x(x), y(y), z(z) {}
}a[N];
PII b[N];

bool cmp(const TPoint &a, const TPoint &b) {
    return a.x < b.x || (a.x == b.x && a.y < b.y) || (a.x == b.x && a.y == b.y && a.z < b.z);
}

bool cmpX(const PII &a, const PII &b) {
    return a.first < b.first || (a.first == b.first && a.second < b.second);
}

bool cmpY(const PII &a, const PII &b) {
    return a.second < b.second || (a.second == b.second && a.first < b.first);
}

Node *build(PII *b, int l, int r, int o) {
    if (l >= r) return NULL;
    Node *p = C++;
    p->o = o;
    int mid = (l + r) / 2;
    nth_element(b + l, b + mid, b + r, o ? cmpY : cmpX);
    p->e = b[mid];
    p->cur = p->sub = PII(0, 0);
    p->lc = build(b, l, mid, o ^ 1);
    p->rc = build(b, mid + 1, r, o ^ 1);
    return p;
}

inline void update(PII &cur, const PII &v) {
    if (v.first > cur.first) {
        cur = v;
    } else if (cur.first == v.first) {
        cur.second = cur.second + v.second;
    }
}

void add(Node *p, const PII &e, const PII &v) {
    update(p->sub, v);
    if (e == p->e) {
        update(p->cur, v);
        return;
    } else {
        bool c = p->o ? cmpY(e, p->e) : cmpX(e, p->e);
        if (c) {
            add(p->lc, e, v);
        } else {
            add(p->rc, e, v);
        }
    }
}

PII ans;

void get(Node *p, const PII &e, int maxx = INF, int maxy = INF) {
    if (!p) return;
    if (p->sub.first < ans.first) return;
    if (maxx <= e.first && maxy <= e.second) {
        update(ans, p->sub);
    } else {
        if (p->e.first <= e.first && p->e.second <= e.second) update(ans, p->cur);
        if (p->o) {
            if (p->e.second <= e.second) get(p->rc, e, maxx, maxy);
            get(p->lc, e, maxx, min(maxy, p->e.second));
        } else {
            if (p->e.first <= e.first) get(p->rc, e, maxx, maxy);
            get(p->lc, e, min(maxx, p->e.first), maxy);
        }
    }
}

int main() {
#ifdef cwj
    freopen("E.in", "r", stdin);
#endif
    scanf("%d", &Tc);
    rep (ri, Tc) {
        scanf("%d", &n);
        rep (i, n) {
            scanf("%d%d%d", &a[i].x, &a[i].y, &a[i].z);
            b[i] = make_pair(a[i].y, a[i].z);
        }
        sort(b, b + n);
        m = unique(b, b + n) - b;
        C = pool;
        Node *root = build(b, 0, m, 0);
        sort(a, a + n, cmp);
        rep (i, n) {
            ans = PII(0, 0);
            get(root, PII(a[i].y, a[i].z));
            PII cur;
            if (ans.first == 0) {
                cur.first = 1;
                cur.second = 1;
            } else {
                cur.first = ans.first + 1;
                cur.second = ans.second;
            }
//            printf("cur %d %d %d\n", i, cur.first, cur.second);
            add(root, PII(a[i].y, a[i].z), cur);
        }
        printf("%d %d\n", root->sub.first, root->sub.second & mask);
    }
    return 0;
}

留下评论

您的邮箱地址不会被公开。 必填项已用 * 标注