题目描述

Bob has a tree with $n$ nodes, the set of the edges of this tree is $T$.

Let $B$​ denote the edge set of n-clique, formally $B=\{(i,j)∣1≤i<j≤n\}$​​

Now give you an integer $k$, you need to find the number of pair $(X,Y)$ satisfies the following conditions:

  1. $X\subseteq T, Y\subseteq B$​.

  2. $|X|=n-1-k,|Y|=k$.

  3. $X\cup Y$ is an edge set of a tree with $n$ nodes.

    The answer may be very large, you only need to output the answer module $998244353$​.

输入描述:

The first line has two integers $n,k$.

Then there are $n-1$ lines, each line has two integers $u,v$ denote an edge $(u,v)$ in $T$.

$2\leq n\leq 5\times 10^4$

$1\leq k\leq min(100,n-1)$

简明题意

给定一棵树, 删 k 条边再加 k 条边使得它还是一棵树, 求方案数。

题解

对于任意一种删除树上 $k$​ 条边的方案均会将原来的树分为 $k+1$​​ 个连通块,假设其连通块大小分别为 $s_1,s_2\dots s_{k+1}$,那么有结论加上 $k$ 条边使其联通成一棵树的方案数为

证明可参见OI-wiki/prufer序列

对于 $n^{k+1-2}$ 我们可以提出来不管,剩下的问题转化为求所有删 $k$​ 条边方案下连通块大小乘积的累和,但这并不好求,而枚举所有删边方案计算也不现实。

可以考虑其组合意义,上述问题等价于求删除树上 $k$ 条边,并且在得到的 $k+1$ 个连通块中每个连通块取一个点的方案数,该问题可以用树形dp解决。

设 $dp[i][j][0/1]$​ 为对于以 $i$​​ 为根的子树,其中共删了 $j$ 条边,已经取了 $0/1$ 个点的方案数进行树形dp即可,时间复杂度 $O(nk)$.

代码

1
#include <iostream>
2
#include <cstdio>
3
#include <vector>
4
#include <cstring>
5
using namespace std;
6
const long long mod = 998244353;
7
const long long N = 5e4 + 7;
8
vector<vector<int> > e;
9
long long dp[N][105][2];
10
long long tmp[105][2];
11
int n, k;
12
int sze[N];
13
void dfs(int x, int fa){
14
    dp[x][0][1] = dp[x][0][0] = 1;
15
    sze[x] = 1;
16
    for(auto v: e[x]){
17
        if(v == fa) continue;
18
        dfs(v, x);
19
        memcpy(tmp, dp[x], sizeof(tmp));
20
        memset(dp[x], 0, sizeof(dp[x]));
21
        for(int i = 0; i <= k && i < sze[x]; i++){
22
            for(int j = 0; i + j <= k && j < sze[v]; j++){
23
                dp[x][i + j][0] = (dp[x][i + j][0] + tmp[i][0] * dp[v][j][0]) % mod;
24
                if(i + j != k) dp[x][i + j + 1][0] = (dp[x][i + j + 1][0] + tmp[i][0] * dp[v][j][1]) % mod;
25
                dp[x][i + j][1] = (dp[x][i + j][1] + tmp[i][1] * dp[v][j][0] + tmp[i][0] * dp[v][j][1]) % mod;
26
                if(i + j != k) dp[x][i + j + 1][1] = (dp[x][i + j + 1][1] + tmp[i][1] * dp[v][j][1]) % mod;
27
            }
28
        }
29
        sze[x] += sze[v];
30
    }
31
}
32
long long ksm(long long a, long long mi){
33
    long long res = 1, base = a;
34
    while(mi){
35
        if(mi & 1) res = res * base % mod;
36
        mi >>= 1;
37
        base = base * base % mod;
38
    }
39
    return res;
40
}
41
int main(){
42
    int u, v;
43
    scanf("%d%d", &n, &k);
44
    e.resize(n + 1);
45
    for(int i = 1; i < n; i++){
46
        scanf("%d%d", &u, &v);
47
        e[u].push_back(v);
48
        e[v].push_back(u);
49
    }
50
    dfs(1, 0);
51
    long long ans;
52
    ans = dp[1][k][1] * ksm(n, k - 1) % mod;
53
    printf("%lld\n", ans);
54
    return 0;
55
}