感觉这个也算是个知识点但是到处都没有教程捏
这里有一篇写的很好的关于dfs序的教程。
dfs序就是把树上问题转化为序列问题,从而通过树状数组或线段树进行维护。
dfs序本身不难理解,难的是与树状数组以及线段树结合,再加上巨大的码量。一下的四个例题码量平均100+。
多码预警
下面用一些例题讲解。
1.单点修改,子树查询
在dfs序中,$x$ 的子树在两次 $x$ 出现的中间,且是连续的。
所以问题就转化为序列上单点修改,区间查询。树状数组维护即可。
code:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78
| #include<bits/stdc++.h> using namespace std; #define in read() #define ll long long inline ll read(){ ll x=0,f=1;char c=getchar(); while(c>'9'||c<'0'){ if(c=='-') f=-1; c=getchar(); } while(c>='0'&&c<='9'){ x=(x<<1)+(x<<3)+c-'0'; c=getchar(); } return x*f; } const int maxn=1e6+10; int n,m,r,u,v,opt; struct edge{ int u,v,nxt; }e[maxn<<1]; int cnt=0,h[maxn]; void add(int u,int v){ e[++cnt]=(edge){u,v,h[u]}; h[u]=cnt; } ll val[maxn],c[maxn<<1]; int lowbit(int x){ return x&(-x); } void update(int x,int v){ while(x<=n){ c[x]+=v; x+=lowbit(x); } } ll query(int x){ ll res=0; while(x){ res+=c[x]; x-=lowbit(x); } return res; } int st[maxn],ed[maxn],tim=0; bool vis[maxn]; void dfs(int x){ st[x]=++tim;vis[x]=true; for(int i=h[x];i;i=e[i].nxt){ int v=e[i].v; if(!vis[v]) dfs(v); } ed[x]=tim; } int main(){ n=in;m=in;r=in; for(int i=1;i<=n;i++){ val[i]=in; } for(int i=1;i<n;i++){ u=in;v=in; add(u,v);add(v,u); } dfs(r); for(int i=1;i<=n;i++) update(st[i],val[i]); while(m--){ opt=in; if(opt==1){ u=in;v=in; update(st[u],v); } if(opt==2){ u=in; printf("%lld\n",query(ed[u])-query(st[u]-1)); } } return 0; }
|
2.子树修改,子树查询
同理,问题可以转化为区间修改区间查询,很容易想到线段树对伐。然后就有了以下代码:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105
| #include<bits/stdc++.h> using namespace std; #define in read() #define ll long long inline int read(){ int x=0,f=1;char c=getchar(); while(c>'9'||c<'0'){ if(c=='-') f=-1; c=getchar(); } while(c>='0'&&c<='9'){ x=(x<<1)+(x<<3)+c-'0'; c=getchar(); } return x*f; } const int maxn=1e6+10; struct tree{ int l,r; ll sum,tag; }t[maxn<<1]; int n,m,r,u,v,opt; struct edge{ int u,v,nxt; }e[maxn<<1]; int cnt=0,h[maxn]; void add(int u,int v){ e[++cnt]=(edge){u,v,h[u]}; h[u]=cnt; } int a[maxn<<1],val[maxn]; void build(int p,int l,int r){ t[p].l=l;t[p].r=r;t[p].tag=0; if(l==r){ t[p].sum=a[l]; return ; } int mid=l+r>>1; build(p<<1,l,mid); build(p<<1|1,mid+1,r); t[p].sum=t[p<<1].sum+t[p<<1|1].sum; } void pushdown(int p){ if(t[p].tag){ t[p<<1].sum+=t[p].tag*(t[p<<1].r-t[p<<1].l+1); t[p<<1|1].sum+=t[p].tag*(t[p<<1|1].r-t[p<<1|1].l+1); t[p<<1].tag+=t[p].tag; t[p<<1|1].tag+=t[p].tag; t[p].tag=0; } } void update(int l,int r,int p,int val){ if(l<=t[p].l&&t[p].r<=r){ t[p].sum+=val*(t[p].r-t[p].l+1); t[p].tag+=val; return ; } pushdown(p); int mid=t[p].l+t[p].r>>1; if(l<=mid) update(l,r,p<<1,val); if(r>mid) update(l,r,p<<1|1,val); t[p].sum=t[p<<1].sum+t[p<<1|1].sum; } ll query(int l,int r,int p){ if(l<=t[p].l&&t[p].r<=r) return t[p].sum; pushdown(p); int mid=t[p].l+t[p].r>>1; ll ans=0; if(l<=mid) ans+=query(l,r,p<<1); if(mid<r) ans+=query(l,r,p<<1|1); return ans; } int st[maxn],ed[maxn],tim=0; bool vis[maxn]; void dfs(ll x){ st[x]=++tim;vis[x]=true; for(int i=h[x];i;i=e[i].nxt){ ll v=e[i].v; if(!vis[v]) dfs(v); } ed[x]=tim; } int main(){ n=in;m=in;r=in; for(int i=1;i<=n;i++) val[i]=in; for(int i=1;i<n;i++){ u=in;v=in; add(u,v);add(v,u); } dfs(r); for(int i=1;i<=n;i++) a[st[i]]=val[i]; build(1,1,n); while(m--){ opt=in; if(opt==1){ u=in;v=in; update(st[u],ed[u],1,v); } if(opt==2){ u=in; printf("%lld\n",query(st[u],ed[u],1)); } } return 0; }
|
被WA、RE、MLE轮流搞废的线段树。
为什么呢?一看数据 $1e6$ 。
还是得转过头来考虑用树状数组维护。
设原数组为 $a_i$ , 差分数组为 $d_i$
则有: $a_x=\sum_{i=1}^x d_i$
要求:$\sum_{i=1}^{x}a_i$
联立得:$ans=\sum_{i=1}^x\sum_{j=1}^id_j=\sum_{i=1}^x(x-i+1)d_i=(x+1)\sum_{i=1}^x-\sum_{i=1}^x id_i$
然后用两个树状数组分别维护即可。
code:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86
| #include<bits/stdc++.h> using namespace std; #define in read() #define ll long long inline ll read(){ ll x=0,f=1;char c=getchar(); while(c>'9'||c<'0'){ if(c=='-') f=-1; c=getchar(); } while(c>='0'&&c<='9'){ x=(x<<1)+(x<<3)+c-'0'; c=getchar(); } return x*f; } const int maxn=1e6+10; ll n,m,r,u,v,opt; struct edge{ ll u,v,nxt; }e[maxn<<1]; ll cnt=0,h[maxn]; void add(ll u,ll v){ e[++cnt]=(edge){u,v,h[u]}; h[u]=cnt; } ll val[maxn],c[maxn<<1][2],sum[maxn]; ll lowbit(ll x){ return x&(-x); } void update(ll x,ll v,ll k){ while(x<=n){ c[x][k]+=v; x+=lowbit(x); } } ll query(ll x,ll k){ ll res=0; while(x){ res+=c[x][k]; x-=lowbit(x); } return res; } int st[maxn],ed[maxn],tim=0,dfn[maxn<<2]; bool vis[maxn]; void dfs(ll x,ll fa){ st[x]=++tim; dfn[tim]=x; for(int i=h[x];i;i=e[i].nxt){ int v=e[i].v; if(v!=fa) dfs(v,x); } ed[x]=tim; } int main(){ n=in;m=in;r=in; for(int i=1;i<=n;i++) val[i]=in; for(int i=1;i<n;i++){ u=in;v=in; add(u,v);add(v,u); } dfs(r,-1); for(int i=1;i<=n;i++) sum[i]=sum[i-1]+val[dfn[i]]; while(m--){ opt=in; if(opt==1){ u=in;v=in; update(st[u],v,0); update(ed[u]+1,-v,0); update(st[u],st[u]*v,1); update(ed[u]+1,-v*(ed[u]+1),1); } if(opt==2){ u=in; ll l=st[u], r=ed[u]; ll ans=sum[r]-sum[l-1]; printf("%lld\n",(r+1)*query(r,0)- l*query(l-1,0)- query(r,1)+ query(l-1,1)+ans); } } return 0; }
|
3.链上修改,单点查询,子树查询(树上差分)
单点查询和子树查询前文已经提过了
对于链上修改,每个点维护一个树上前缀和,那么答案就是 $val_x+val_y-val_{lca}-val_{fa_lca}$,这就是所谓的差分 。可以开两个树状数组,一个用来维护点权 $val_x$ ,一个用来维护前缀和。
对于单点查询和子树查询也可以用第一个树状数组解决。
题中还要求 $lca$,这里可以用欧拉序+$ST$ 表来求解,详细在代码中可以看到。
code:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123
| #include<bits/stdc++.h> using namespace std; #define in read() #define ll long long inline int read(){ int x=0,f=1;char c=getchar(); while(c>'9'||c<'0'){ if(c=='-') f=-1; c=getchar(); } while(c>='0'&&c<='9'){ x=(x<<1)+(x<<3)+c-'0'; c=getchar(); } return x*f; } const int maxn=1e6+10; int n,m,r,u,v,opt; struct edge{ int u,v,nxt; }e[maxn<<1]; int cnt=0,h[maxn]; void add(int u,int v){ e[++cnt]=(edge){u,v,h[u]}; h[u]=cnt; } int val[maxn]; struct bittree{ ll c[maxn<<1];int siz; inline int lowbit(int x){ return x&-x; } void update(int k,ll v){ if(!k)return; while(k<=siz){ c[k]+=v;k+=lowbit(k); } } ll query(int k){ ll ans=0; while(k){ ans+=c[k];k-=lowbit(k); } return ans; } }; bittree A,B; int st[maxn],ed[maxn],tim=0,dfn[maxn<<1],dep[maxn],fath[maxn]; void dfs(int x,int fa,int d){ st[x]=++tim; dfn[tim]=x; dep[x]=d; fath[x]=fa; for(int i=h[x];i;i=e[i].nxt){ int v=e[i].v; if(v==fa) continue; dfs(v,x,d+1); dfn[++tim]=x; } ed[x]=tim; } int f[maxn<<1][24]; int mmin(int a,int b){ return dep[a]<dep[b]?a:b; } void STinit(int len){ for(int i=1;i<=len;i++) f[i][0]=dfn[i]; for(int j=1;(1<<j)<=len;j++){ for(int i=1;i+(1<<j-1)<len;i++){ f[i][j]=mmin(f[i][j-1],f[i+(1<<j-1)][j-1]); } } } int lca(int x,int y){ x=st[x];y=st[y]; if(x>y) swap(x,y); int k=0; while((1<<k+1)<=(y-x+1)) k++; return mmin(f[x][k],f[y-(1<<k)+1][k]); } signed main(){ n=in;m=in;r=in; for(int i=1;i<=n;i++) val[i]=in; for(int i=1;i<n;i++){ u=in;v=in; add(u,v); add(v,u); } dfs(r,0,0); A.siz=B.siz=tim; STinit(tim); for(int i=1;i<=n;i++){ A.update(st[i],(ll)val[i]*(dep[i]+1)); A.update(st[fath[i]],(ll)-val[i] *(dep[fath[i]]+1)); B.update(st[i],(ll)val[i]); B.update(st[fath[i]],(ll)-val[i]); } while(m--){ cin>>opt; if(opt==1){ int x=in,y=in;int p=in; int LCA=lca(x,y); A.update(st[x],(ll)p*(dep[x]+1)); A.update(st[y],(ll)p*(dep[y]+1)); A.update(st[LCA],(ll)-p*(dep[LCA]+1)); A.update(st[fath[LCA]],(ll)-p*(dep[fath[LCA]]+1)); B.update(st[x],(ll)p); B.update(st[y],(ll)p); B.update(st[LCA],(ll)-p); B.update(st[fath[LCA]],(ll)-p); } if(opt==2){ u=in; printf("%lld\n",B.query(ed[u])-B.query(st[u]-1)); } if(opt==3){ u=in; printf("%lld\n",A.query(ed[u])-A.query(st[u]-1)-(ll)dep[u]*(B.query(ed[u])-B.query(st[u]-1))); } } return 0; }
|
4.单点修改,子树修改,链上查询
解法 :分开考虑两个修改问题,所以先考虑分成修改点权+查询链和修改子树+查询链两个子问题。
子问题 1 :修改点权+查询链可以通过维护每个点到根这条链的点权和来实现。这样,修改点权相当于修改子树,查询链相当于查询点权,然后通过树上差分实现查询任意链的点权和。如对于修改点 $x$ ,则修改 $st[x]+v$ , $[ed[x]+1]-v$,这样,查询x即得到 $[root,x]$ 的路径值。
子问题 2 :修改子树+查询链,通过维护每个点x 到根这条链的点权和 $v[x]$ 。设 $y$ 在 $x$ 子树中,询问点为 $y$ ,$dep[x]$ 表示 $x$ 的深度,对于 $x$ 点权修改,对 $y$ 的影响为 $w(dep[y]-dep[x]+1)=wdep[y] – w(dep[x]-1)$,发现第二项与 $y$ 无关。$-w(dep[x]-1)$ 用一颗树状数组维护,同时与子问题1一起使用。w值以一颗树状数组维护。
code:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105
| #include<bits/stdc++.h> using namespace std; const int N=1e6+10; #define ll long long int dep[N],st[N],ed[N],fa[N][22]; int n,m,r,cnt=0; int a[N]; struct Bit{ ll c[N];int siz; inline int lowbit(int x){ return x&-x; } Bit(int siz=0):siz(siz){ memset(c,0,sizeof(c)); } void add(int k,ll v){ if(!k)return; while(k<=siz){ c[k]+=v;k+=lowbit(k); } } ll sum(int k){ ll ans=0; while(k){ ans+=c[k];k-=lowbit(k); } return ans; } }w1,w2; struct node{ int u,v; }g[N*2]; int nxt[N*2],first[N],tot=0; inline void add(int u,int v){ g[++tot].u=u;g[tot].v=v; nxt[tot]=first[u];first[u]=tot; } int lca(int u,int v){ if(dep[u]<dep[v])swap(u,v); for(int i=20;i>=0;i--){ if(dep[fa[u][i]]>=dep[v])u=fa[u][i]; if(u==v)return u; } for(int i=20;i>=0;i--){ if(fa[u][i]!=fa[v][i]){ u=fa[u][i]; v=fa[v][i]; } } return fa[u][0]; } void dfs(int u,int f){ st[u]=++cnt; dep[u]=dep[f]+1; fa[u][0]=f; for(int i=1;i<21;i++){ fa[u][i]=fa[fa[u][i-1]][i-1]; } for(int i=first[u];i;i=nxt[i]){ int v=g[i].v; if(f==v)continue; dfs(v,u); } ed[u]=cnt; } ll query(int x){ return w1.sum(st[x])+w2.sum(st[x])*dep[x]; } ll ask(int x,int y){ int LCA=lca(x,y); return query(x)+query(y)-query(LCA)-query(fa[LCA][0]); } int main(){ scanf("%d%d%d",&n,&m,&r); w1.siz=w2.siz=n; for(int i=1;i<=n;i++)scanf("%d",a+i); for(int i=1;i<n;i++){ int u,v; scanf("%d%d",&u,&v); add(u,v);add(v,u); } dfs(r,0); for(int i=1;i<=n;i++){ w1.add(st[i],(ll)a[i]); w1.add(ed[i]+1,(ll)-a[i]); } while(m--){ int op,x,y; scanf("%d%d%d",&op,&x,&y); if(op==1){ w1.add(st[x],(ll)y); w1.add(ed[x]+1,(ll)-y); }else if(op==2){ w1.add(st[x],(ll)-y*(dep[x]-1)); w1.add(ed[x]+1,(ll)y*(dep[x]-1)); w2.add(st[x],(ll)y); w2.add(ed[x]+1,(ll)-y); }else{ printf("%lld\n",ask(x,y)); } } return 0; }
|