[WC2019]数树(树形dp+多项式exp)

[WC2019]数树(树形dp+多项式exp)

Part1

相同边连接的点同一颜色,直接模拟即可

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
namespace pt1{
int fa[N],sz[N];
map <int,int> M[N];
int Find(int x){ return fa[x]==x?x:fa[x]=Find(fa[x]); }
void Solve(){
rep(i,1,n) fa[i]=i;
rep(i,2,n){
int x=rd(),y=rd();
if(x>y) swap(x,y);
M[x][y]=1;
}
rep(i,2,n) {
int x=rd(),y=rd();
if(x>y) swap(x,y);
if(M[x][y]) fa[Find(x)]=Find(y);
}
int ans=1;
rep(i,1,n) if(Find(i)==i) ans=1ll*ans*y%P;
printf("%d\n",ans);
}
}

Part2

相同边连接的点同一颜色,即在相同边构成的树上形成了若干联通块

很容易想到可以强制一些边保留,设保留 \(i\) 条边的方案数是 \(F_i\) ,则答案就是 \(\sum_i F_i\cdot y^{n-i}\)

考虑 \(dp\) 那些边相同,但是不好直接计算剩下边不同的方案,所以考虑计算最多有 \(i\) 条边相同的方案数,即

\[G_i=\sum_{j=i}C(j,i)F_j\]

二项式反演得到 \(F_i=\sum_{j=i}(-1)^{j-i}C(j,i)G_j\)

设分成了 \(m\) 个联通块,大小分别为 \(size_i\) ,则这些联通块随意构成树的方案数就是 \(n^{m-2}\cdot\prod size_i\)

根据上述性质可以写出一个简单的 \(O(n^4)\) 树形dp求得 \(G_i\) ,即 \(dp[i][j][k]\) 表示在 \(i\) 的子树里,有 \(j\) 条边相同,当前还剩下一个大小为 \(k\) 的联通块,每多转移一条相同边,系数是 \(\frac{1}{ny}\)

考虑优化 \(dp\)

联通块大小的问题,可以转化为每次在联通块里选择一个关键点的方案数, \(dp\) 第三维 \(0/1\) 表示当前联通块里是否已经选出了关键点

每次断开一个联通块时必须已经存在关键点

答案是

\(\sum_i F_i\cdot y^{n-i}\)

\(=\sum_i y^{n-i} \sum_{j=i}(-1)^{j-i}C(j,i)G_j\)

\(=y^n G_j\sum_{i=0}^j(-1)^{j-i}C(j,i)y^{-i}\)

发现右边的式子 \(\sum_0^j(-1)^{j-i}C(j,i)y^{-i}=(\frac{1}{y}-1)^j\)

那么直接把 \(\frac{1}{y}-1\) 带入作为保留一条边的转移系数,消去了第二维

那么这个 \(\text{dp}\) 可以被优化到 \(O(n)\)

\[ \ \]

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
namespace pt2{
vector <int> G[N];
int dp[N][2],g[2],Inv;
void dfs(int u,int f){
dp[u][0]=dp[u][1]=1;
for(int v:G[u]) if(v!=f) {
dfs(v,u);
g[0]=g[1]=0;
rep(i,0,1) rep(j,0,1) {
if(!i||!j) g[i|j]=(g[i|j]+1ll*dp[u][i]*dp[v][j]%P*Inv)%P;
if(j) g[i]=(g[i]+1ll*dp[u][i]*dp[v][j])%P;
}
dp[u][0]=g[0],dp[u][1]=g[1];
}
}
void Solve() {
rep(i,2,n) {
int u=rd(),v=rd();
G[u].pb(v),G[v].pb(u);
}
Inv=(qpow(y)-1)*qpow(n)%P;
dfs(1,0);
ll res=dp[1][1]*qpow(y,n)%P*qpow(n,P+n-3)%P;
printf("%lld\n",res);
}
}

Part3

有了上面的 \(dp\) ,这一部分就简单多了,设分成了 \(m\) 个联通块,每个大小为 \(a_i\) ,则贡献为

\[\begin{aligned}\frac{n!\cdot a_i^{a_i-2}\cdot (n^{m-2})^2(\frac{1}{y}-1)^{n-m}(\frac{1}{n}^{n-m})^2\cdot a_i^2}{\prod a_i! m !}\end{aligned}\]

即枚举每个联通块生成树的数量,且需要考虑两棵树分别的联通块之间的连边数量,这一部分需要平方

很显然,可以直接对于 \([x^i]F(x)=\frac{1}{i!}\cdot (\frac{1}{n^2}\cdot (\frac{1}{y}-1))^{i-1} i^2i^{i-2}\) 这个多项式求exp得到

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
const int M=1<<18|10,K=17;
typedef vector <int> Poly;

int w[M],rev[M],Inv[M];
void Init(){
ll t=qpow(3,(P-1)>>K>>1);
w[1<<K]=1;
rep(i,(1<<K)+1,(1<<(K+1))-1) w[i]=w[i-1]*t%P;
drep(i,(1<<K)-1,1) w[i]=w[i<<1];
Inv[0]=Inv[1]=1;
rep(i,2,M-1) Inv[i]=1ll*(P-P/i)*Inv[P%i]%P;
}
int Init(int n){
int R=1,cc=-1;
while(R<n) R<<=1,cc++;
rep(i,1,R-1) rev[i]=(rev[i>>1]>>1)|((i&1)<<cc);
return R;
}

void NTT(int n,Poly &a,int f){
if((int)a.size()<n) a.resize(n);
rep(i,1,n-1) if(rev[i]<i) swap(a[i],a[rev[i]]);
for(int i=1;i<n;i<<=1) {
int *e=w+i;
for(int l=0;l<n;l+=i*2){
for(int j=l;j<l+i;++j){
int t=1ll*a[j+i]*e[j-l]%P;
a[j+i]=a[j]-t,Mod2(a[j+i]);
a[j]+=t,Mod1(a[j]);
}
}
}
if(f==-1) {
reverse(a.begin()+1,a.end());
rep(i,0,n-1) a[i]=1ll*a[i]*Inv[n]%P;
}
}

Poly operator * (Poly a,Poly b){
int n=a.size(),m=b.size();
int R=Init(n+m-1);
NTT(R,a,1),NTT(R,b,1);
rep(i,0,R-1) a[i]=1ll*a[i]*b[i]%P;
NTT(R,a,-1),a.resize(n+m-1);
return a;
}

Poly Poly_Inv(Poly a){
int n=a.size();
if(n==1) return {(int)qpow(a[0])};
Poly b=a; b.resize((n+1)/2),b=Poly_Inv(b);
int R=Init(n*2);
NTT(R,a,1),NTT(R,b,1);
rep(i,0,R-1) a[i]=1ll*b[i]*(2-1ll*a[i]*b[i]%P+P)%P;
NTT(R,a,-1); a.resize(n);
return a;
}

Poly Deri(Poly a){
rep(i,1,a.size()-1) a[i-1]=1ll*i*a[i]%P;
a.pop_back();
return a;
}
Poly IDeri(Poly a){
a.pb(0);
drep(i,a.size()-2,0) a[i+1]=1ll*a[i]*Inv[i+1]%P;
a[0]=0;
return a;
}

Poly Ln(Poly a){
int n=a.size();
a=Deri(a)*Poly_Inv(a),a.resize(n+1);
return IDeri(a);
}

Poly Exp(Poly a){
int n=a.size();
if(n==1) return Poly{1};
Poly b=a; b.resize((n+1)/2),b=Exp(b);
b.resize(n); Poly c=Ln(b);
rep(i,0,n-1) c[i]=a[i]-c[i],Mod2(c[i]);
c[0]++,c=c*b;
c.resize(n);
return c;
}

void Solve() {
int I=(qpow(y)-1)*qpow(1ll*n*n%P)%P;
Init();
Poly F(n+1);
for(int i=1,FInv=1;i<=n;FInv=1ll*FInv*Inv[++i]%P){
F[i]=qpow(I,(i-1)) * // 保留i-1条边
(i==1?1:qpow(i,i-2))%P // i个点生成树
* i%P * i%P //
* FInv%P; // 阶乘常数
}
F=Exp(F);
rep(i,1,n) F[n]=1ll*F[n]*i%P;
ll res=F[n]*qpow(y,n)%P*qpow(n,2*(P+n-3))%P;
printf("%lld\n",res);
}