题目描述:
给你一棵 n
个节点的无向树,节点编号为 1
到 n
。给你一个整数 n
和一个长度为 n - 1
的二维整数数组 edges
,其中 edges[i] = [ui, vi]
表示节点 ui
和 vi
在树中有一条边。
请你返回树中的 合法路径数目 。
如果在节点 a
到节点 b
之间 恰好有一个 节点的编号是质数,那么我们称路径 (a, b)
是 合法的 。
注意:
- 路径
(a, b)
指的是一条从节点 a
开始到节点 b
结束的一个节点序列,序列中的节点 互不相同 ,且相邻节点之间在树上有一条边。 - 路径
(a, b)
和路径 (b, a)
视为 同一条 路径,且只计入答案 一次 。
数据范围:
$1\le n \le 10^5$
题解:
类似树上路径长度为 $k$ 的有多少条。
该题首先要筛素数,使用线性筛。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
| bool notPrime[maxn]; int prime[maxn], cnt; void sieve(int n) { notPrime[1] = true; for(int i = 2; i <= n; ++i) { if(!notPrime[i]) prime[++cnt] = i; for(int j = 1; j <= cnt && i * prime[j] <= n; ++j) { notPrime[i * prime[j]] = true; if(i % prime[j] == 0) break; } } }
|
之前我的模板好像忘了把 $1$ 置为非素数了,被坑了。
然后使用树形dp,使用 $dp[u][0/1]$ 分别表示以 $u$ 为终点不包含素数或包含一个素数的路径条数。
则很容易得到
如果 $u$ 不是素数的话, $dp[u][0/1]$ 可以从 $dp[v][0/1]$ 递推过来;但是如果 $u$ 是素数的话, $dp[u][0]$ 只能为 $0$ ,因为不可能存在终点为 $u$ ,且不包含素数的,因为 $u$ 就是素数。 $dp[u][1]$ 只能从 $dp[v][0]$ 转移。
注意初始条件,初始条件为如果 $u$ 是素数,那么 $dp[u][1] = 1, dp[u][0] = 0$ ;否则 $dp[u][0] = 1, dp[u][1] = 0$ 。在前序遍历的地方初始化初始条件,在后续遍历的地方转移。
答案就是 $dp[u][0] \times dp[v][1] + dp[u][1] \times dp[v][0]$ 。可以想象成每次像列表中加入一个 $v$ 。那么需要统计前缀和,新加入一个 $v$ 时,和前面的前缀和计算对答案的贡献。
也可以使用连通分量的做法:
如果一个节点是质数,那么从他的所有的非质数孩子出发,一直沿着非质数节点走,得到连通分量。那么该节点 $u$ 的贡献就是, $size[v_1] \times size[v_2] + (size[v_1]+size[v_2]) \times size[v_3]\cdots $ ,最后需要加上从 $u$ 出发的路径数量,就是所有size
的和。
也可以使用并查集处理,将不是素数的边使用并查集链接,然后针对素数节点,遍历他的所有相邻的非素数节点,然后使用并查集得到连通块大小,然后和上一种方法一样的求解方法。
代码:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85
| auto optimize_cpp_stdio = []() { std::ios::sync_with_stdio(false); std::cin.tie(nullptr); std::cout.tie(nullptr); return 0; }(); class Solution { public: const static int maxn = 1e5 + 10; const static int maxm = 1e5 + 10; const int INF = 0x3f3f3f3f; vector<int> prime; int cnt; vector<bool> notPrime; vector<vector<long long>> dp; vector<vector<int>> g; long long ans; void sieve(int n) { notPrime[1] = true; for (int i = 2; i <= n; i++) { if (!notPrime[i]) prime[++cnt] = i;
for (int j = 1; prime[j] * i <= n && j <= cnt; j++) { notPrime[i * prime[j]] = 1; if (i % prime[j] == 0) break; } } } void dfs(int u, int fa) { if (!notPrime[u]) { dp[u][1] = 1; } else { dp[u][0] = 1; } for (int i = 0; i < g[u].size(); ++i) { int v = g[u][i]; if (v == fa) continue; dfs(v, u); ans += dp[v][1] * dp[u][0]; ans += dp[v][0] * dp[u][1]; if (!notPrime[u]) { dp[u][1] += dp[v][0]; } else { dp[u][1] += dp[v][1]; dp[u][0] += dp[v][0]; } } } long long countPaths(int n, vector<vector<int>> &edges) { cnt = 0; prime.resize(n); notPrime.resize(n + 1, false); sieve(n); dp.resize(n + 1, vector<long long>(2, 0)); g.resize(n + 1); for (auto &edge : edges) { int u = edge[0]; int v = edge[1]; g[u].emplace_back(v); g[v].emplace_back(u); } ans = 0; dfs(1, 0); return ans; } };
|
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84
| vector<bool> notPrime; vector<int> prime; auto seive = [](int maxx) { std::ios::sync_with_stdio(false); std::cin.tie(nullptr); std::cout.tie(nullptr); notPrime.resize(maxx + 1); notPrime[0] = notPrime[1] = true; for (int i = 2; i <= maxx; ++i) { if (!notPrime[i]) prime.emplace_back(i); for (int j = 0; j < prime.size() && prime[j] * i <= maxx; ++j) { notPrime[prime[j] * i] = true; if (i % prime[j] == 0) break; } } return 0; }(1e5); class Solution { public: const static int maxn = 1e5 + 10; const static int maxm = 1e5 + 10; const static long long mod = 1e9 + 7; const long long INF_LL = 0x3f3f3f3f3f3f3f3f; const int INF = 0x3f3f3f3f; vector<vector<int>> g; vector<vector<int>> dp; vector<int> size; long long ans = 0; void dfs(int u, int fa, vector<int> &path) { path.emplace_back(u); for (int i = 0; i < g[u].size(); ++i) { int v = g[u][i]; if (!notPrime[v] || v == fa) continue; dfs(v, u, path); } } long long countPaths(int n, vector<vector<int>> &edges) { g.resize(n + 1); dp.resize(n + 1, vector<int>(2)); size.resize(n + 1); for (auto &edge : edges) { int u = edge[0]; int v = edge[1]; g[u].emplace_back(v); g[v].emplace_back(u); } vector<int> path; path.reserve(n); long long ans = 0; for (int i = 1; i <= n; ++i) { if (notPrime[i]) continue; long long sum = 0; for (auto &v : g[i]) { if (!notPrime[v]) continue; if (size[v] == 0) { path.clear(); dfs(v, i, path); for (auto &node : path) size[node] = path.size(); } ans += size[v] * sum; sum += size[v]; } ans += sum; } return ans; } };
|
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96
| vector<bool> notPrime; vector<int> prime; auto seive = [](int maxx) { std::ios::sync_with_stdio(false); std::cin.tie(nullptr); std::cout.tie(nullptr); notPrime.resize(maxx + 1); notPrime[0] = notPrime[1] = true; for (int i = 2; i <= maxx; ++i) { if (!notPrime[i]) prime.emplace_back(i); for (int j = 0; j < prime.size() && prime[j] * i <= maxx; ++j) { notPrime[prime[j] * i] = true; if (i % prime[j] == 0) break; } } return 0; }(1e5);
struct UF { vector<int> fa; vector<int> size; UF(int n) : fa(n + 1), size(n + 1) { for (int i = 0; i <= n; ++i) { fa[i] = i; size[i] = 1; } } int find(int u) { return fa[u] == u ? u : fa[u] = find(fa[u]); } void unite(int u, int v) { int up = find(u); int vp = find(v); if (up == vp) return; fa[up] = vp; size[vp] += size[up]; } int getSize(int u) { return size[find(u)]; } }; class Solution { public: const static int maxn = 1e5 + 10; const static int maxm = 1e5 + 10; const static long long mod = 1e9 + 7; const long long INF_LL = 0x3f3f3f3f3f3f3f3f; const int INF = 0x3f3f3f3f; vector<vector<int>> g; vector<vector<int>> dp; long long ans = 0; long long countPaths(int n, vector<vector<int>> &edges) { g.resize(n + 1); dp.resize(n + 1, vector<int>(2)); UF uf(n + 1); for (auto &edge : edges) { int u = edge[0]; int v = edge[1]; g[u].emplace_back(v); g[v].emplace_back(u); if (notPrime[u] && notPrime[v]) uf.unite(u, v); } for (int u = 1; u <= n; ++u) { if (notPrime[u]) continue; long long sum = 0; for (auto &v : g[u]) { if (!notPrime[v]) continue; int size = uf.getSize(v); ans += size * sum; sum += size; } ans += sum; } return ans; } };
|