連結:https://tioj.ck.tp.edu.tw/problems/1306

題目大意

有一個字串 T,詢問多個字串 Pi 分別在 T 中出現幾次。

題解

很久以前寫過這題,用的是雜湊,不過看起來WA了之後這題就被擱置了,無聊又點開這題,意外地找到當初WA的BUG… 給大家猜猜看:

1
2
3
4
5
6
7
ll pw(ll a,ll e)
{
if(e==0)return 1;
ll t = pow( a*a%MOD , e/2);
if( e&1 )return a*t%MOD;
return t;
}

字串的題目可以考慮使用雜湊的技巧來把題目水過,以這題來說,就可以用Rolling hash直接比對,不過實作上時限稍緊,要把細節處理好才能AC,比較要注意的重點是處理相減的取模時,用 (A-B+M)%MC=A-B; if(C<0)C+=M 慢上至少兩倍,再雜湊中,這常用的細節會大幅度的影響執行時間,要十分注意。

除了雜湊之外,這也是經典的AC自動機要解決的問題,AC自動機可以在 O(T+Pi) 找出每個 PT 中出現幾次,不過OJ開的記憶體有點小,而AC自動機空間時間常數偏大,而且Trie的節點會造成大量閒置空間,會很浪費記憶體。個人是調參數加上靜態分配記憶體硬過的,直接開優化模板應該不會那麼卡。

下面的AC自動機是評演算法直覺裸刻上的版本,通常這東西會細節優化好放在模板中直接用。AC自動機大致上的步驟就是:

  1. 建立Trie
  2. 蓋出fail邊 (build)
  3. 開始走路 (eval)
  4. 把走路的獲得的資料蒐集起來 (calc)

這裡calc的實作是把所有node拓譜排序後,再把答案DP回來

AC Code

AC 自動機

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
#include<bits/stdc++.h>
using namespace std;

struct node{
int next[26];
int fail;
int tag;
int hit;
};

int nwid;
vector<node> buf(310000);
int newnode()
{
nwid++;
if( nwid == buf.size() ) while(1);
memset(&buf[nwid], 0,sizeof(buf[0]));

return nwid;
}
using pnode = int;
#define nullptr 0

pnode root = nullptr;
vector<int> query;
map<int,int> ans;
map<string,int> sid;

void build()
{
queue<pnode> qu;
qu.push(root);

buf[root].fail = root;

while(!qu.empty())
{
pnode ptr = qu.front();
qu.pop();

int i=-1;
for(auto e:buf[ptr].next)
{
i++;
if(!e) continue;
qu.push(e);
auto tmp = buf[ptr].fail;
while( tmp!=root && !buf[tmp].next[i] )
tmp=buf[tmp].fail;
if( ptr!=root && buf[tmp].next[i] )
tmp = buf[tmp].next[i];
buf[e].fail = tmp;
}
}
}

void eval(const string &s)
{
pnode ptr = root;
for(int c:s)
{
c-='a';
while( ptr!=root && !buf[ptr].next[c] )
ptr = buf[ptr].fail;
if( buf[ptr].next[c])
ptr = buf[ptr].next[c];
buf[ptr].hit++;
}
}

map<pnode,int> deg;
void dfs(pnode r)
{
deg[r];
deg[buf[r].fail]++;
for(auto e:buf[r].next)
if(e) dfs(e);
}
void calc()
{
deg.clear();
ans.clear();
dfs(root);
deg[root]--;
queue<pnode> qu;
for(auto p:deg)
if(p.second==0)
qu.push(p.first);

while( !qu.empty() )
{
auto ptr = qu.front();
qu.pop();

if(buf[ptr].tag)
ans[buf[ptr].tag]=buf[ptr].hit;

deg[buf[ptr].fail]--;
if( deg[buf[ptr].fail] == 0 )
qu.push(buf[ptr].fail);
buf[buf[ptr].fail].hit += buf[ptr].hit;
}
}

int main()
{
ios::sync_with_stdio(false);
cin.tie(0);

int T,N;
string tmpl, str;
cin>>T;
while(T--)
{
query.clear();
sid.clear();
nwid=0;
root = newnode();
cin>>tmpl>>N;

for(int i=0;i<N;++i)
{
cin>>str;
if( sid[str] == 0 )
sid[str] = sid.size();
query.push_back(sid[str]);

auto ptr = root;
for(char c:str)
{
if( buf[ptr].next[c-'a'] == nullptr )
buf[ptr].next[c-'a'] = newnode();
ptr = buf[ptr].next[c-'a'];
}
buf[ptr].tag = query.back();
}
build();
eval(tmpl);
calc();
for(int i:query)cout<<ans[i]<<'\n';
}
}

Rolling Hash

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<vector>
#include<cstring>
using namespace std;

typedef long long ll;
#define P1 26
#define MOD 1000000009

ll pw1[10005];
void init()
{
pw1[0]=1;
for(int i=1;i<10005;++i)
{
pw1[i]=pw1[i-1]*P1%MOD;
}
}

ll ht1[10001];
char tmpl[10001];

inline void build(char *str)
{
ht1[0]=0;
int i=1;
while( str[i-1] )
{
ht1[i] = (ht1[i-1]*P1+str[i-1])%MOD;
++i;
}
}

inline ll gethash1(int L, int R)
{
ll tmp = ht1[R] - ht1[L-1]*pw1[R-L+1]%MOD;
if( tmp < 0 ) tmp += MOD;
return tmp;
}

inline ll calc( char *str, ll p )
{
ll res = 0;
int i=0;
while(str[i])
{
res = ( res * p + str[i] ) % MOD;
++i;
}
return res;
}
int main()
{
ios::sync_with_stdio(false);
cin.tie(0);
int T,N;
ll hA,lent,len;
cin>>T;
init();
while(T--)
{
cin>>tmpl;
lent = strlen(tmpl);
build(tmpl);
cin>>N;
while(N--)
{
cin>>tmpl;
len = strlen(tmpl);
hA = calc(tmpl,P1);

int sum = 0;
int pos = 0;
while( pos + len <= lent )
{
if( hA == gethash1( pos+1,pos+len ) ) sum++;
pos++;
}
cout<<sum<<'\n';
}
}
return 0;
}