0%

3045. 统计前后缀下标对 II

3045.统计前后缀下标对II

题目描述:

给你一个下标从 0 开始的字符串数组 words。

定义一个布尔函数 isPrefixAndSuffix,它接受两个字符串参数 str1 和 str2:

当 $str1$ 同时是 $str2$ 的前缀(prefix)和后缀(suffix)时, $isPrefixAndSuffix(str1, str2)$ 返回 $true$ ,否则返回 $false$ 。

例如,isPrefixAndSuffix("aba", "ababa") 返回 true,因为 "aba" 既是 "ababa" 的前缀,也是 "ababa" 的后缀,但是 isPrefixAndSuffix("abc", "abcd") 返回 false。

以整数形式,返回满足 i < j 且 isPrefixAndSuffix(words[i], words[j]) 为 true 的下标对 (i, j) 的数量。

数据范围:

$1\le words.len \le 10^5, \sum words[i].len \le 5\times 10^5$

$1\le words[i].len \le 10^5$

题解:

总体思路就是从后往前,将每一个字符串相同的前后缀 $s$ 处理出来,然后存到 $hash$ 表中,然后每次遇到 $words[i]$ ,计数 $mp[words[i]]$ 。

关键是怎么快速求出一个字符串所有的前后缀相同的字符串。

两种方法,一种是字符串 hash,而且 $map$ 只需要存 $hash$ 值。

一种是使用 $Z$ 函数,求出 $Z$ 函数之后,倒着遍历,可以每次往前缀加字符,但是 $map$ 存字符串的话,内存会炸,还是需要hash。

能否不使用hash?可以直接使用 $trie$ 把字符串前缀存起来,后面在里面查找出现次数。但是需要注意的是,插入 $trie$ 时应该从上次的节点继续插入,不能从头开始,从头开始太慢了。

如何判断一个字符串 $s$ 是另一个字符串 $t$ 的前后缀。

  • 字符串 $hash$

  • $Z$ 函数, $Z[i] = len - i$

  • 分解成 $pair$

    将字符串 $s$ 分解为 $[(s[0], s[n - 1]), (s[1], s[n - 2]), \cdots , (s[n - 1],s[0])]$

    将字符串 $t$(长度为 $m$, $m \ge n$ )分解为 $[(t[0], t[m - 1]),(t[1], t[m - 2]), \cdots, (t[m - 1], t[0])]$

    如果 $s$ 是 $t$ 的前后缀,必须满足 $s$ 的分解序列是 $t$ 分解序列的前缀,可以使用字典树解决。(不能用 $vector$ 了,得用 $unordered\_map$ )

代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
// Evaluated once during static initialization: turns off synchronization
// between the C++ streams and C stdio and clears the tie on cin and cout.
// The variable only stores the lambda's return value (0).
auto optimize_cpp_stdio = []
{
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);
    std::cout.tie(nullptr);
    return 0;
}();
/**
 * LeetCode 3045 (count prefix-and-suffix pairs), rolling-hash variant.
 *
 * Words are scanned from last to first. For every word, each border (a prefix
 * that is also a suffix, the whole word included) is recorded by its hash, so
 * an earlier word only has to look up its own hash to learn how many later
 * words it is a border of.
 *
 * Words are expected to be non-empty; an empty word contributes nothing.
 */
class Solution
{
public:
    const static int maxn = 1e5 + 10;
    const static int maxm = 1e5 + 10;
    const long long INF_LL = 0x3f3f3f3f3f3f3f3f;
    const int INF = 0x3f3f3f3f;
    // 2^61 - 1 is prime; a 61-bit modulus keeps accidental collisions between
    // hundreds of thousands of stored hashes negligible.
    long long mod = (1LL << 61) - 1;
    int base = 131;
    // pow[k] = base^k mod `mod`; hash[k] = hash of the first k characters of
    // the word most recently passed to hashString().
    vector<long long> pow;
    vector<long long> hash;
    // Set once init() has been run for this object.
    bool seeded = false;

    /** Picks a run-dependent base so the hash cannot be targeted in advance. */
    void init()
    {
        srand(time(0));
        base = 257 + rand() % 100003;
    }

    /** Returns a * b mod `mod`; the product goes through a 128-bit intermediate. */
    long long mulMod(long long a, long long b) const
    {
        using u128 = unsigned __int128;
        return static_cast<long long>(static_cast<u128>(a) * static_cast<u128>(b) % static_cast<u128>(mod));
    }

    /** Rebuilds `pow` and `hash` for the prefixes of `s`. */
    void hashString(string &s)
    {
        const int len = s.length();
        pow.assign(len + 1, 1LL);
        hash.assign(len + 1, 0LL); // hash[0] is the hash of the empty prefix
        for (int i = 0; i < len; ++i)
        {
            pow[i + 1] = mulMod(pow[i], base);
            hash[i + 1] = (mulMod(hash[i], base) + static_cast<unsigned char>(s[i])) % mod;
        }
    }

    /**
     * Hash of the substring [l, r] (inclusive) of the last hashed word:
     * hash[r + 1] - hash[l] * base^(r - l + 1), reduced into [0, mod).
     */
    long long getHash(int l, int r)
    {
        return (hash[r + 1] + mod - mulMod(hash[l], pow[r - l + 1])) % mod;
    }

    /**
     * @param words input words
     * @return number of pairs (i, j), i < j, where words[i] is both a prefix
     *         and a suffix of words[j]
     */
    long long countPrefixSuffixPairs(vector<string> &words)
    {
        if (!seeded)
        {
            init();
            seeded = true;
        }
        // border hash -> number of already-processed (later) words having it
        unordered_map<long long, long long> borderCount;
        long long ans = 0;
        for (int i = static_cast<int>(words.size()) - 1; i >= 0; --i)
        {
            const int len = words[i].length();
            hashString(words[i]);
            // Look up without inserting, so misses do not grow the table.
            const auto hit = borderCount.find(getHash(0, len - 1));
            if (hit != borderCount.end())
                ans += hit->second;
            for (int j = 1; j <= len; ++j)
            {
                const long long prefixHash = getHash(0, j - 1);
                const long long suffixHash = getHash(len - j, len - 1);
                if (prefixHash == suffixHash)
                    ++borderCount[prefixHash];
            }
        }
        return ans;
    }
};
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
// One-time stream setup performed while globals are initialized: C++ streams
// stop synchronizing with C stdio, and the tie on cin and cout is cleared.
// The stored value is just the lambda's result, 0.
auto optimize_cpp_stdio = []
{
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);
    std::cout.tie(nullptr);
    return 0;
}();
/**
 * LeetCode 3045 (count prefix-and-suffix pairs), Z-function variant that keys
 * the counting table by the border strings themselves.
 *
 * Words are scanned from last to first; each border (prefix that is also a
 * suffix, the whole word included) of a processed word is counted in a map,
 * and an earlier word looks itself up in that map.
 *
 * Note: every border is stored as a full string key, so memory grows with the
 * total length of all borders. The hashed or trie-based variants avoid that.
 */
class Solution
{
public:
    const static int maxn = 1e5 + 10;
    const static int maxm = 1e5 + 10;
    const long long INF_LL = 0x3f3f3f3f3f3f3f3f;
    const int INF = 0x3f3f3f3f;
    // Z[i] = length of the longest common prefix of the word and its suffix
    // starting at i. Z[0] is not derived by the scan; it is set to len.
    vector<int> Z;

    /**
     * Fills Z for the first `len` characters of `str`.
     * A non-positive `len` leaves Z empty.
     */
    void zFunction(const char *str, int len)
    {
        if (len <= 0)
        {
            Z.clear();
            return;
        }
        Z.assign(len, 0);
        Z[0] = len;
        // [l, r] is the right-most segment known to match a prefix.
        for (int i = 1, l = 0, r = 0; i < len; ++i)
        {
            if (i <= r && Z[i - l] < r - i + 1)
            {
                // Fully inside the known segment: reuse the mirrored value.
                Z[i] = Z[i - l];
            }
            else
            {
                Z[i] = max(0, r - i + 1);
                // Anything beyond the known segment is compared directly.
                while (i + Z[i] < len && str[Z[i]] == str[i + Z[i]])
                    ++Z[i];
            }
            if (i + Z[i] - 1 > r)
            {
                l = i;
                r = i + Z[i] - 1;
            }
        }
    }

    /**
     * @param words input words
     * @return number of pairs (i, j), i < j, where words[i] is both a prefix
     *         and a suffix of words[j]
     */
    long long countPrefixSuffixPairs(vector<string> &words)
    {
        // border string -> number of already-processed (later) words having it
        unordered_map<string, long long> borderCount;
        long long ans = 0;
        for (int i = static_cast<int>(words.size()) - 1; i >= 0; --i)
        {
            const string &word = words[i];
            const int len = word.length();
            zFunction(word.c_str(), len);
            // find() instead of operator[]: a miss must not copy the word
            // into the table as a new key.
            const auto hit = borderCount.find(word);
            if (hit != borderCount.end())
                ans += hit->second;
            string prefix;
            prefix.reserve(len);
            for (int k = 1; k <= len; ++k)
            {
                prefix += word[k - 1];
                // The suffix starting at len - k has length k; it equals the
                // length-k prefix exactly when Z there covers all of it.
                if (Z[len - k] == k)
                    ++borderCount[prefix];
            }
        }
        return ans;
    }
};
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
// Static-initialization hook: disables C++/C stdio synchronization and clears
// the tie on cin and cout. Only the lambda's return value (0) is kept.
auto optimize_cpp_stdio = []
{
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);
    std::cout.tie(nullptr);
    return 0;
}();
/**
 * LeetCode 3045 (count prefix-and-suffix pairs), Z-function + prefix-hash
 * variant.
 *
 * Words are scanned from last to first. The Z-function identifies which
 * prefixes of a word are also suffixes (borders); each border is recorded by
 * the polynomial hash of that prefix, and an earlier word looks up the hash
 * of its whole text.
 */
class Solution
{
public:
    const static int maxn = 1e5 + 10;
    const static int maxm = 1e5 + 10;
    const long long INF_LL = 0x3f3f3f3f3f3f3f3f;
    const int INF = 0x3f3f3f3f;
    // 2^61 - 1 is prime; a 61-bit modulus keeps accidental collisions between
    // hundreds of thousands of stored hashes negligible.
    long long mod = (1LL << 61) - 1;
    int base = 1331;
    // Z[i] = length of the longest common prefix of the word and its suffix
    // starting at i. Z[0] is not derived by the scan; it is set to len.
    vector<int> Z;

    /**
     * Fills Z for the first `len` characters of `str`.
     * A non-positive `len` leaves Z empty.
     */
    void zFunction(const char *str, int len)
    {
        if (len <= 0)
        {
            Z.clear();
            return;
        }
        Z.assign(len, 0);
        Z[0] = len;
        // [l, r] is the right-most segment known to match a prefix.
        for (int i = 1, l = 0, r = 0; i < len; ++i)
        {
            if (i <= r && Z[i - l] < r - i + 1)
            {
                // Fully inside the known segment: reuse the mirrored value.
                Z[i] = Z[i - l];
            }
            else
            {
                Z[i] = max(0, r - i + 1);
                // Anything beyond the known segment is compared directly.
                while (i + Z[i] < len && str[Z[i]] == str[i + Z[i]])
                    ++Z[i];
            }
            if (i + Z[i] - 1 > r)
            {
                l = i;
                r = i + Z[i] - 1;
            }
        }
    }

    /** Returns (h * base + c) mod `mod` via a 128-bit intermediate product. */
    long long appendChar(long long h, char c) const
    {
        using u128 = unsigned __int128;
        const u128 scaled = static_cast<u128>(h) * static_cast<u128>(base) + static_cast<unsigned char>(c);
        return static_cast<long long>(scaled % static_cast<u128>(mod));
    }

    /**
     * @param words input words
     * @return number of pairs (i, j), i < j, where words[i] is both a prefix
     *         and a suffix of words[j]
     */
    long long countPrefixSuffixPairs(vector<string> &words)
    {
        // border hash -> number of already-processed (later) words having it
        unordered_map<long long, long long> borderCount;
        long long ans = 0;
        for (int i = static_cast<int>(words.size()) - 1; i >= 0; --i)
        {
            const string &word = words[i];
            const int len = word.length();
            zFunction(word.c_str(), len);

            long long whole = 0;
            for (char c : word)
                whole = appendChar(whole, c);
            // find() instead of operator[] so a miss does not add a key.
            const auto hit = borderCount.find(whole);
            if (hit != borderCount.end())
                ans += hit->second;

            long long prefixHash = 0;
            for (int k = 1; k <= len; ++k)
            {
                prefixHash = appendChar(prefixHash, word[k - 1]);
                // The suffix starting at len - k has length k; it equals the
                // length-k prefix exactly when Z there covers all of it.
                if (Z[len - k] == k)
                    ++borderCount[prefixHash];
            }
        }
        return ans;
    }
};
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
// Executed once when globals are initialized: switches off synchronization
// of the C++ streams with C stdio and clears the tie on cin and cout.
// The variable keeps only the lambda's return value (0).
auto optimize_cpp_stdio = []
{
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);
    std::cout.tie(nullptr);
    return 0;
}();

/**
 * One trie node over the alphabet 'a'..'z'.
 * Child pointers are non-owning; nodes are created and freed by Trie.
 */
struct TrieNode
{
    const static int M = 26;
    vector<TrieNode *> next; // next[c - 'a'], nullptr when the edge is absent
    bool exist;              // true once some insertion has ended at this node
    int val;                 // number of insertions that ended at this node
    TrieNode() : next(M, nullptr), exist(false), val(0)
    {
    }
};

/**
 * Counting trie for lowercase words. Characters outside 'a'..'z' are not
 * supported (they would index outside TrieNode::next).
 *
 * The trie owns all of its nodes and frees them in its destructor by walking
 * a flat list, so destruction does not recurse along deep paths.
 */
struct Trie
{
    TrieNode *root;
    vector<TrieNode *> nodes; // every node allocated by this trie, root included

    Trie() : root(nullptr)
    {
        root = allocate();
    }
    ~Trie()
    {
        for (TrieNode *node : nodes)
            delete node;
    }
    // Copying would make two tries free the same nodes.
    Trie(const Trie &) = delete;
    Trie &operator=(const Trie &) = delete;

    /** Inserts `s` starting from the root and counts one occurrence at its end. */
    void insert(string &s)
    {
        insert(s, root);
    }

    /**
     * Continues an insertion: walks/creates the path for `s` starting at node
     * `cur` (which must be a node of this trie), marks the final node and
     * increments its counter.
     * @return the node reached, so the caller can extend the path later
     *         without re-walking it from the root
     */
    TrieNode *insert(string &s, TrieNode *cur)
    {
        for (char c : s)
        {
            const int ch = c - 'a';
            if (!cur->next[ch])
                cur->next[ch] = allocate();
            cur = cur->next[ch];
        }
        cur->exist = true;
        cur->val++;
        return cur;
    }

    /** @return true when some insertion ended exactly at the path of `s`. */
    bool exist(string &s)
    {
        const TrieNode *node = locate(s);
        return node != nullptr && node->exist;
    }

    /** @return how many insertions ended exactly at the path of `s` (0 if the path is absent). */
    int query(string &s)
    {
        const TrieNode *node = locate(s);
        return node != nullptr ? node->val : 0;
    }

private:
    /** Creates a node and registers it for cleanup. */
    TrieNode *allocate()
    {
        // Reserve the slot first so a failed push_back cannot strand a node.
        nodes.push_back(nullptr);
        nodes.back() = new TrieNode();
        return nodes.back();
    }

    /** Follows the path of `s` from the root; nullptr when an edge is missing. */
    const TrieNode *locate(const string &s) const
    {
        const TrieNode *cur = root;
        for (char c : s)
        {
            cur = cur->next[c - 'a'];
            if (!cur)
                return nullptr;
        }
        return cur;
    }
};
class Solution
{
public:
const static int maxn = 1e5 + 10;
const static int maxm = 1e5 + 10;
const long long INF_LL = 0x3f3f3f3f3f3f3f3f;
const int INF = 0x3f3f3f3f;
long long mod = 1e9 + 7;
int base = 1331;
// 不需要求 Z[0],Z[0] = len
int Z[maxn];
void zFunction(const char *str, int len)
{
Z[0] = len;
for (int i = 1, l = 0, r = 0; i < len; ++i)
{
if (i <= r && Z[i - l] < r - i + 1)
Z[i] = Z[i - l];
else
{
Z[i] = max(0, r - i + 1);
while (i + Z[i] < len && str[Z[i]] == str[i + Z[i]]) // 对多余的需要暴力求出
++Z[i];
}
if (i + Z[i] - 1 > r)
l = i, r = i + Z[i] - 1;
}
}
long long countPrefixSuffixPairs(vector<string> &words)
{
int n = words.size();
Trie trie;
long long ans = 0;
for (int i = n - 1; i >= 0; --i)
{
int len = words[i].length();
zFunction(words[i].c_str(), len);
ans += trie.query(words[i]);
string tmp;
TrieNode *cur = trie.root;
for (int j = len - 1; j >= 0; --j)
{
tmp.push_back(words[i][len - 1 - j]);
if (Z[j] == len - j)
{
cur = trie.insert(tmp, cur);
tmp.clear();
}
}
}
return ans;
}
};
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
// Start-up stream configuration, run once as part of static initialization:
// no synchronization with C stdio, and the tie on cin and cout is cleared.
// The variable holds the lambda's result, 0.
auto optimize_cpp_stdio = []
{
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);
    std::cout.tie(nullptr);
    return 0;
}();

/**
 * One trie node whose outgoing edges are keyed by arbitrary integers.
 * Child pointers are non-owning; nodes are created and freed by Trie.
 */
struct TrieNode
{
    const static int M = 26;
    unordered_map<int, TrieNode *> next; // edge key -> child
    bool exist;                          // true once some insertion has ended here
    int val;                             // number of insertions that ended here
    TrieNode() : exist(false), val(0)
    {
    }
};

/**
 * Counting trie over integer symbols (with a string convenience overload).
 *
 * Every symbol is shifted by -'a' before being used as an edge key. The same
 * shift is applied on insertion and on lookup, so the string overload and the
 * vector<int> overloads address the same edges.
 *
 * The trie owns all of its nodes and frees them in its destructor by walking
 * a flat list, so destruction does not recurse along deep paths.
 */
struct Trie
{
    TrieNode *root;
    vector<TrieNode *> nodes; // every node allocated by this trie, root included

    Trie() : root(nullptr)
    {
        root = allocate();
    }
    ~Trie()
    {
        for (TrieNode *node : nodes)
            delete node;
    }
    // Copying would make two tries free the same nodes.
    Trie(const Trie &) = delete;
    Trie &operator=(const Trie &) = delete;

    /** Inserts the characters of `s` from the root and counts one occurrence at the end. */
    void insert(string &s)
    {
        TrieNode *cur = root;
        for (char c : s)
            cur = childOrCreate(cur, c - 'a');
        cur->exist = true;
        cur->val++;
    }

    /**
     * Continues an insertion: walks/creates the path for the symbols of `s`
     * starting at node `cur` (which must be a node of this trie), marks the
     * final node and increments its counter.
     * @return the node reached, so the caller can extend the path later
     */
    TrieNode *insert(vector<int> &s, TrieNode *cur)
    {
        for (int symbol : s)
            cur = childOrCreate(cur, symbol - 'a');
        cur->exist = true;
        cur->val++;
        return cur;
    }

    /**
     * @return how many insertions ended exactly at the path of `s`, or 0 when
     *         the path is absent. The trie is not modified by a lookup.
     */
    int query(vector<int> &s)
    {
        const TrieNode *cur = root;
        for (int symbol : s)
        {
            // find() rather than operator[]: a miss must not create an edge.
            const auto edge = cur->next.find(symbol - 'a');
            if (edge == cur->next.end() || !edge->second)
                return 0;
            cur = edge->second;
        }
        return cur->val;
    }

private:
    /** Creates a node and registers it for cleanup. */
    TrieNode *allocate()
    {
        // Reserve the slot first so a failed push_back cannot strand a node.
        nodes.push_back(nullptr);
        nodes.back() = new TrieNode();
        return nodes.back();
    }

    /** Returns the child of `cur` under `key`, creating it when missing. */
    TrieNode *childOrCreate(TrieNode *cur, int key)
    {
        const auto edge = cur->next.find(key);
        if (edge != cur->next.end() && edge->second)
            return edge->second;
        TrieNode *created = allocate();
        cur->next[key] = created;
        return created;
    }
};
class Solution
{
public:
const static int maxn = 1e5 + 10;
const static int maxm = 1e5 + 10;
const long long INF_LL = 0x3f3f3f3f3f3f3f3f;
const int INF = 0x3f3f3f3f;
long long mod = 1e9 + 7;
int base = 1331;
// 不需要求 Z[0],Z[0] = len
int Z[maxn];
void zFunction(const char *str, int len)
{
Z[0] = len;
for (int i = 1, l = 0, r = 0; i < len; ++i)
{
if (i <= r && Z[i - l] < r - i + 1)
Z[i] = Z[i - l];
else
{
Z[i] = max(0, r - i + 1);
while (i + Z[i] < len && str[Z[i]] == str[i + Z[i]]) // 对多余的需要暴力求出
++Z[i];
}
if (i + Z[i] - 1 > r)
l = i, r = i + Z[i] - 1;
}
}
int getPairHash(char a, char b)
{
return (a - 'a') << 5 | (b - 'a');
}
long long countPrefixSuffixPairs(vector<string> &words)
{
int n = words.size();
Trie trie;
long long ans = 0;
for (int i = n - 1; i >= 0; --i)
{
int len = words[i].length();
zFunction(words[i].c_str(), len);
vector<int> tmp;
for (int j = 0; j < len; ++j)
{
tmp.emplace_back(getPairHash(words[i][j], words[i][len - 1 - j]));
}
ans += trie.query(tmp);
TrieNode *cur = trie.root;
tmp.clear();
for (int j = len - 1; j >= 0; --j)
{
tmp.emplace_back(getPairHash(words[i][len - 1 - j], words[i][j]));
if (Z[j] == len - j)
{
cur = trie.insert(tmp, cur);
tmp.clear();
}
}
}
return ans;
}
};