D3T3 Three
Desciption:
给定一颗 $n$ 个节点的无根树,在树上选三个互不相同的节点,使得三个节点两两之间距离相等,输出方案数。
Solution:
首先,我不是很会 $dp$
其次,我完全不会长链剖分优化 $dp$
最后,我
其实 $dp$ 部分还是听懂了。
设 $f_{i,j}$ 表示以 $i$ 为根的子树中距离 $i$ 为 $j$ 的点数,$g_{i,j}$ 表示 $i$ 的子树中有多少两个点的 $lca$ 到 $i$ 的距离为 $d-j$ ,两个点到他们 $lca$ 的距离是 $d$ 。
很容易发现这两个状态可以互补(?
因此对于一对父子 $(p,v)$ 有如下转移式:
但是代码就完全就看不懂了,什么指针转移力(哭
Code:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83
| #include<bits/stdc++.h> using namespace std; #define in read() #define ll long long inline int read(){ int x=0,f=1;char c=getchar(); while(c<'0'||c>'9'){ if(c=='-') f=-1; c=getchar(); } while(c>='0'&&c<='9'){ x=(x<<1)+(x<<3)+c-'0'; c=getchar(); } return x*f; } const int maxn=3e5+10; vector<int> e[maxn]; int n,u,v; ll pool[maxn<<4]; ll* top=pool; ll *f[maxn],*g[maxn]; ll* get(int len){ ll* t=top;top+=len; return t; } ll ans=0; int dep[maxn],hson[maxn]; void dfs1(int x,int fa){ dep[x]=0;hson[x]=0; for(int i=0;i<e[x].size();i++){ int v=e[x][i]; if(v==fa) continue; dfs1(v,x); dep[x]=max(dep[x],dep[v]+1); if(dep[v]>dep[hson[x]]) hson[x]=v; } } void dfs2(int x,int fa,int &maxlen,int blank){ maxlen=max(maxlen,dep[x]); if(hson[x]){ dfs2(hson[x],x,maxlen,blank+1); ans+=g[hson[x]][1]; f[x]=f[hson[x]]-1; f[x][0]=1; g[x]=g[hson[x]]+1; }else{ f[x]=get(maxlen+5+blank)+blank; g[x]=get(maxlen+5+blank); f[x][0]=1; } for(int i=0;i<e[x].size();i++){ int v=e[x][i],mlen=0; if(v==fa||v==hson[x]) continue; dfs2(v,x,mlen,0); for(int j=0;j<dep[v];j++) ans+=f[x][j]*g[v][j+1]; for(int j=1;j<=dep[v]+1;j++) ans+=g[x][j]*f[v][j-1]; for(int j=1;j<=dep[v]+1;j++) g[x][j]+=f[x][j]*f[v][j-1]; for(int j=0;j<=dep[v];j++) f[x][j+1]+=f[v][j]; for(int j=1;j<=dep[v];j++) g[x][j-1]+=g[v][j]; } } int main(){ while(1){ n=in; if(n==0) break; memset(pool,0,sizeof(pool)); memset(hson,0,sizeof(hson)); memset(dep,0,sizeof(dep)); for(int i=1;i<=n;i++) e[i].clear(); ans=0; for(int i=1;i<n;i++){ u=in;v=in; e[u].push_back(v); e[v].push_back(u); } int mxlen=0; dfs1(1,0); dfs2(1,0,mxlen,0); printf("%lld\n",ans); } return 0; }
|