0%

2867. 统计树中的合法路径数目

2867.统计树中的合法路径数目

题目描述:

给你一棵 n 个节点的无向树,节点编号为 1n 。给你一个整数 n 和一个长度为 n - 1 的二维整数数组 edges ,其中 edges[i] = [ui, vi] 表示节点 uivi 在树中有一条边。

请你返回树中的 合法路径数目

如果在节点 a 到节点 b 之间 恰好有一个 节点的编号是质数,那么我们称路径 (a, b)合法的

注意:

  • 路径 (a, b) 指的是一条从节点 a 开始到节点 b 结束的一个节点序列,序列中的节点 互不相同 ,且相邻节点之间在树上有一条边。
  • 路径 (a, b) 和路径 (b, a) 视为 同一条 路径,且只计入答案 一次

数据范围:

$1\le n \le 10^5$

题解:

类似树上路径长度为 $k$ 的有多少条。

该题首先要筛素数,使用线性筛。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
bool notPrime[maxn];
int prime[maxn], cnt;
void sieve(int n)
{
notPrime[1] = true;
for(int i = 2; i <= n; ++i)
{
if(!notPrime[i]) prime[++cnt] = i;
for(int j = 1; j <= cnt && i * prime[j] <= n; ++j)
{
notPrime[i * prime[j]] = true;
if(i % prime[j] == 0) break;
}
}
}

之前我的模板好像忘了把 $1$ 置为非素数了,被坑了。

然后使用树形dp,使用 $dp[u][0/1]$ 分别表示以 $u$ 为终点不包含素数或包含一个素数的路径条数。

则很容易得到

如果 $u$ 不是素数的话, $dp[u][0/1]$ 可以从 $dp[v][0/1]$ 递推过来;但是如果 $u$ 是素数的话, $dp[u][0]$ 只能为 $0$ ,因为不可能存在终点为 $u$ ,且不包含素数的,因为 $u$ 就是素数。 $dp[u][1]$ 只能从 $dp[v][0]$ 转移。

注意初始条件,初始条件为如果 $u$ 是素数,那么 $dp[u][1] = 1, dp[u][0] = 0$ ;否则 $dp[u][0] = 1, dp[u][1] = 0$ 。在前序遍历的地方初始化初始条件,在后续遍历的地方转移。

答案就是 $dp[u][0] \times dp[v][1] + dp[u][1] \times dp[v][0]$ 。可以想象成每次像列表中加入一个 $v$ 。那么需要统计前缀和,新加入一个 $v$ 时,和前面的前缀和计算对答案的贡献。

也可以使用连通分量的做法:

如果一个节点是质数,那么从他的所有的非质数孩子出发,一直沿着非质数节点走,得到连通分量。那么该节点 $u$ 的贡献就是, $size[v_1] \times size[v_2] + (size[v_1]+size[v_2]) \times size[v_3]\cdots $ ,最后需要加上从 $u$ 出发的路径数量,就是所有size 的和。

也可以使用并查集处理,将不是素数的边使用并查集链接,然后针对素数节点,遍历他的所有相邻的非素数节点,然后使用并查集得到连通块大小,然后和上一种方法一样的求解方法。

代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
auto optimize_cpp_stdio = []()
{
std::ios::sync_with_stdio(false);
std::cin.tie(nullptr);
std::cout.tie(nullptr);
return 0;
}();
class Solution
{
public:
const static int maxn = 1e5 + 10;
const static int maxm = 1e5 + 10;
const int INF = 0x3f3f3f3f;
vector<int> prime;
int cnt;
vector<bool> notPrime; // true 不是质数,false是质数
vector<vector<long long>> dp;
vector<vector<int>> g;
long long ans;
void sieve(int n)
{
notPrime[1] = true;
for (int i = 2; i <= n; i++)
{
if (!notPrime[i])
prime[++cnt] = i;

for (int j = 1; prime[j] * i <= n && j <= cnt; j++)
{
notPrime[i * prime[j]] = 1;
if (i % prime[j] == 0)
break;
}
}
}
void dfs(int u, int fa)
{
if (!notPrime[u])
{
dp[u][1] = 1;
}
else
{
dp[u][0] = 1;
}
for (int i = 0; i < g[u].size(); ++i)
{
int v = g[u][i];
if (v == fa)
continue;
dfs(v, u);
// 以u为终点的
ans += dp[v][1] * dp[u][0];
ans += dp[v][0] * dp[u][1];
if (!notPrime[u])
{
dp[u][1] += dp[v][0];
}
else
{
dp[u][1] += dp[v][1];
dp[u][0] += dp[v][0];
}
}
}
long long countPaths(int n, vector<vector<int>> &edges)
{
cnt = 0;
prime.resize(n);
notPrime.resize(n + 1, false);
sieve(n);
dp.resize(n + 1, vector<long long>(2, 0));
g.resize(n + 1);
for (auto &edge : edges)
{
int u = edge[0];
int v = edge[1];
g[u].emplace_back(v);
g[v].emplace_back(u);
}
ans = 0;
dfs(1, 0);
return ans;
}
};
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
vector<bool> notPrime;
vector<int> prime;
auto seive = [](int maxx)
{
std::ios::sync_with_stdio(false);
std::cin.tie(nullptr);
std::cout.tie(nullptr);
notPrime.resize(maxx + 1);
notPrime[0] = notPrime[1] = true;
for (int i = 2; i <= maxx; ++i)
{
if (!notPrime[i])
prime.emplace_back(i);
for (int j = 0; j < prime.size() && prime[j] * i <= maxx; ++j)
{
notPrime[prime[j] * i] = true;
if (i % prime[j] == 0)
break;
}
}
return 0;
}(1e5);
class Solution
{
public:
const static int maxn = 1e5 + 10;
const static int maxm = 1e5 + 10;
const static long long mod = 1e9 + 7;
const long long INF_LL = 0x3f3f3f3f3f3f3f3f;
const int INF = 0x3f3f3f3f;
vector<vector<int>> g;
vector<vector<int>> dp; // dp[u][0]表示以u为终点不包含素数的路径条数,dp[u][1]表示以u为终点包含素数的路径条数。
vector<int> size;
long long ans = 0;
void dfs(int u, int fa, vector<int> &path)
{
path.emplace_back(u);
for (int i = 0; i < g[u].size(); ++i)
{
int v = g[u][i];
if (!notPrime[v] || v == fa)
continue;
dfs(v, u, path);
}
}
long long countPaths(int n, vector<vector<int>> &edges)
{
g.resize(n + 1);
dp.resize(n + 1, vector<int>(2));
size.resize(n + 1);
for (auto &edge : edges)
{
int u = edge[0];
int v = edge[1];
g[u].emplace_back(v);
g[v].emplace_back(u);
}
vector<int> path;
path.reserve(n);
long long ans = 0;
for (int i = 1; i <= n; ++i)
{
if (notPrime[i])
continue; // 不是素数
long long sum = 0;
for (auto &v : g[i])
{
if (!notPrime[v])
continue; // 是素数
if (size[v] == 0)
{
path.clear();
dfs(v, i, path);
for (auto &node : path)
size[node] = path.size();
}
ans += size[v] * sum;
sum += size[v];
}
ans += sum;
}
return ans;
}
};
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
vector<bool> notPrime;
vector<int> prime;
auto seive = [](int maxx)
{
std::ios::sync_with_stdio(false);
std::cin.tie(nullptr);
std::cout.tie(nullptr);
notPrime.resize(maxx + 1);
notPrime[0] = notPrime[1] = true;
for (int i = 2; i <= maxx; ++i)
{
if (!notPrime[i])
prime.emplace_back(i);
for (int j = 0; j < prime.size() && prime[j] * i <= maxx; ++j)
{
notPrime[prime[j] * i] = true;
if (i % prime[j] == 0)
break;
}
}
return 0;
}(1e5);

struct UF
{
vector<int> fa;
vector<int> size;
UF(int n) : fa(n + 1), size(n + 1)
{
for (int i = 0; i <= n; ++i)
{
fa[i] = i;
size[i] = 1;
}
}
int find(int u)
{
return fa[u] == u ? u : fa[u] = find(fa[u]);
}
void unite(int u, int v)
{
int up = find(u);
int vp = find(v);
if (up == vp)
return;
fa[up] = vp;
size[vp] += size[up];
}
int getSize(int u)
{
return size[find(u)];
}
};
class Solution
{
public:
const static int maxn = 1e5 + 10;
const static int maxm = 1e5 + 10;
const static long long mod = 1e9 + 7;
const long long INF_LL = 0x3f3f3f3f3f3f3f3f;
const int INF = 0x3f3f3f3f;
vector<vector<int>> g;
vector<vector<int>> dp; // dp[u][0]表示以u为终点不包含素数的路径条数,dp[u][1]表示以u为终点包含素数的路径条数。
long long ans = 0;
long long countPaths(int n, vector<vector<int>> &edges)
{
g.resize(n + 1);
dp.resize(n + 1, vector<int>(2));
UF uf(n + 1);
for (auto &edge : edges)
{
int u = edge[0];
int v = edge[1];
g[u].emplace_back(v);
g[v].emplace_back(u);
if (notPrime[u] && notPrime[v])
uf.unite(u, v);
}
for (int u = 1; u <= n; ++u)
{
if (notPrime[u])
continue;
long long sum = 0;
for (auto &v : g[u])
{
if (!notPrime[v])
continue;
int size = uf.getSize(v);
ans += size * sum;
sum += size;
}
ans += sum;
}
return ans;
}
};