SPOJ 726 大根堆+小根堆 || 平衡树

题目:

https://www.spoj.pl/problems/PRO/

题意描述:

总共N天,每天取已加入的数的最大值和最小值相减并累加到ans,并删掉这两个数。

数据保证每天至少有两个数可以删。

解题分析:

这题可以建两个堆,然后按题意模拟,但是用平衡树更方便。

我用的是SBT。比堆更简短。

插曲:

WA到2B去了。。。原来是要用INT64

code:

#include #include #include #define N 1000000
using namespace std;
struct sbttype{int l,r,w,s;}tr[N];
int a[N],tot=0,n,root=0;

void Left(int &p){
    int t=tr[p].r;
    tr[p].r=tr[t].l;
    tr[t].l=p;
    tr[p].s=tr[tr[p].l].s+tr[tr[p].r].s+1;
    p=t;
    tr[p].s=tr[tr[p].l].s+tr[tr[p].r].s+1;
}

void Right(int &p){
    int t=tr[p].l;
    tr[p].l=tr[t].r;
    tr[t].r=p;
    tr[p].s=tr[tr[p].l].s+tr[tr[p].r].s+1;
    p=t;
    tr[p].s=tr[tr[p].l].s+tr[tr[p].r].s+1;
}

void repair(int &p){
    bool flag=false;
    if (tr[tr[tr[p].l].l].s>tr[tr[p].r].s){
        Right(p);
        flag=true;
    }
    if (tr[tr[tr[p].l].r].s>tr[tr[p].r].s){
        Left(tr[p].l);
        Right(p);
        flag=true;
    }
    if (tr[tr[tr[p].r].r].s>tr[tr[p].l].s){
        Left(p);
        flag=true;
    }
    if (tr[tr[tr[p].r].l].s>tr[tr[p].l].s){
        Right(tr[p].r);
        Left(p);
        flag=true;
    }
    if (flag){
        repair(tr[p].l);
        repair(tr[p].r);
        repair(p);
    }
}

void ins(int &p,int w){
    if (p==0){
        tr[++tot].l=0;
        tr[tot].r=0;
        tr[tot].s=1;
        tr[tot].w=w;
        p=tot;
        return;
    }
    if (w<=tr[p].w) ins(tr[p].l,w);
               else ins(tr[p].r,w);
    tr[p].s=tr[tr[p].l].s+tr[tr[p].r].s+1;
    repair(p);
}

int del(int &p,int w){
    tr[p].s–;
    if (tr[p].w==w || tr[p].l==0 && w<=tr[p].w || tr[p].r==0 && w>tr[p].w){
        int delnum=tr[p].w;
        if (tr[p].l==0 || tr[p].r==0) p=tr[p].l+tr[p].r;
        else tr[p].w=del(tr[p].l,0x7FFFFFFF);
        return delnum;
    }
    if (w<=tr[p].w) return del(tr[p].l,w);
               else return del(tr[p].r,w);
}

int getmax(int k){
    int p=k;
    while (tr[p].r!=0) p=tr[p].r;
    return tr[p].w;
}

int getmin(int k){
    int p=k;
    while (tr[p].l!=0) p=tr[p].l;
    return tr[p].w;
}

int main(){
      while (cin >> n){
        tot=0; root=0;
        long long ans=0;
        for (int i=1;i<=n;i++){
            int num;
            scanf("%d",&num);
            for (int j=1;j<=num;j++){
                int tmp;
                scanf("%d",&tmp);
                ins(root,tmp);
            }
            int tx=getmax(root),ty=getmin(root);
            ans+=(tx-ty);
            del(root,tx);
            del(root,ty);
        }
        cout << ans << endl;
      }
    return 0;
}

留下评论

您的邮箱地址不会被公开。 必填项已用 * 标注