連結:https://tioj.ck.tp.edu.tw/problems/1306
題目大意
有一個字串 ,詢問多個字串 分別在 中出現幾次。
題解
很久以前寫過這題,用的是雜湊,不過看起來WA了之後這題就被擱置了,無聊又點開這題,意外地找到當初WA的BUG… 給大家猜猜看:
1 2 3 4 5 6 7
| ll pw(ll a,ll e) { if(e==0)return 1; ll t = pow( a*a%MOD , e/2); if( e&1 )return a*t%MOD; return t; }
|
字串的題目可以考慮使用雜湊的技巧來把題目水過,以這題來說,就可以用Rolling hash直接比對,不過實作上時限稍緊,要把細節處理好才能AC,比較要注意的重點是處理相減的取模時,用 (A-B+M)%M
比 C=A-B; if(C<0)C+=M
慢上至少兩倍,再雜湊中,這常用的細節會大幅度的影響執行時間,要十分注意。
除了雜湊之外,這也是經典的AC自動機要解決的問題,AC自動機可以在 找出每個 在 中出現幾次,不過OJ開的記憶體有點小,而AC自動機空間時間常數偏大,而且Trie的節點會造成大量閒置空間,會很浪費記憶體。個人是調參數加上靜態分配記憶體硬過的,直接開優化模板應該不會那麼卡。
下面的AC自動機是評演算法直覺裸刻上的版本,通常這東西會細節優化好放在模板中直接用。AC自動機大致上的步驟就是:
- 建立Trie
- 蓋出fail邊 (build)
- 開始走路 (eval)
- 把走路的獲得的資料蒐集起來 (calc)
這裡calc的實作是把所有node拓譜排序後,再把答案DP回來
AC Code
AC 自動機
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142
| #include<bits/stdc++.h> using namespace std;
struct node{ int next[26]; int fail; int tag; int hit; };
int nwid; vector<node> buf(310000); int newnode() { nwid++; if( nwid == buf.size() ) while(1); memset(&buf[nwid], 0,sizeof(buf[0])); return nwid; } using pnode = int; #define nullptr 0
pnode root = nullptr; vector<int> query; map<int,int> ans; map<string,int> sid;
void build() { queue<pnode> qu; qu.push(root);
buf[root].fail = root;
while(!qu.empty()) { pnode ptr = qu.front(); qu.pop(); int i=-1; for(auto e:buf[ptr].next) { i++; if(!e) continue; qu.push(e); auto tmp = buf[ptr].fail; while( tmp!=root && !buf[tmp].next[i] ) tmp=buf[tmp].fail; if( ptr!=root && buf[tmp].next[i] ) tmp = buf[tmp].next[i]; buf[e].fail = tmp; } } }
void eval(const string &s) { pnode ptr = root; for(int c:s) { c-='a'; while( ptr!=root && !buf[ptr].next[c] ) ptr = buf[ptr].fail; if( buf[ptr].next[c]) ptr = buf[ptr].next[c]; buf[ptr].hit++; } }
map<pnode,int> deg; void dfs(pnode r) { deg[r]; deg[buf[r].fail]++; for(auto e:buf[r].next) if(e) dfs(e); } void calc() { deg.clear(); ans.clear(); dfs(root); deg[root]--; queue<pnode> qu; for(auto p:deg) if(p.second==0) qu.push(p.first); while( !qu.empty() ) { auto ptr = qu.front(); qu.pop(); if(buf[ptr].tag) ans[buf[ptr].tag]=buf[ptr].hit; deg[buf[ptr].fail]--; if( deg[buf[ptr].fail] == 0 ) qu.push(buf[ptr].fail); buf[buf[ptr].fail].hit += buf[ptr].hit; } }
int main() { ios::sync_with_stdio(false); cin.tie(0);
int T,N; string tmpl, str; cin>>T; while(T--) { query.clear(); sid.clear(); nwid=0; root = newnode(); cin>>tmpl>>N; for(int i=0;i<N;++i) { cin>>str; if( sid[str] == 0 ) sid[str] = sid.size(); query.push_back(sid[str]);
auto ptr = root; for(char c:str) { if( buf[ptr].next[c-'a'] == nullptr ) buf[ptr].next[c-'a'] = newnode(); ptr = buf[ptr].next[c-'a']; } buf[ptr].tag = query.back(); } build(); eval(tmpl); calc(); for(int i:query)cout<<ans[i]<<'\n'; } }
|
Rolling Hash
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85
| #include<iostream> #include<cstdio> #include<algorithm> #include<vector> #include<cstring> using namespace std;
typedef long long ll; #define P1 26 #define MOD 1000000009
ll pw1[10005]; void init() { pw1[0]=1; for(int i=1;i<10005;++i) { pw1[i]=pw1[i-1]*P1%MOD; } }
ll ht1[10001]; char tmpl[10001];
inline void build(char *str) { ht1[0]=0; int i=1; while( str[i-1] ) { ht1[i] = (ht1[i-1]*P1+str[i-1])%MOD; ++i; } }
inline ll gethash1(int L, int R) { ll tmp = ht1[R] - ht1[L-1]*pw1[R-L+1]%MOD; if( tmp < 0 ) tmp += MOD; return tmp; }
inline ll calc( char *str, ll p ) { ll res = 0; int i=0; while(str[i]) { res = ( res * p + str[i] ) % MOD; ++i; } return res; } int main() { ios::sync_with_stdio(false); cin.tie(0); int T,N; ll hA,lent,len; cin>>T; init(); while(T--) { cin>>tmpl; lent = strlen(tmpl); build(tmpl); cin>>N; while(N--) { cin>>tmpl; len = strlen(tmpl); hA = calc(tmpl,P1);
int sum = 0; int pos = 0; while( pos + len <= lent ) { if( hA == gethash1( pos+1,pos+len ) ) sum++; pos++; } cout<<sum<<'\n'; } } return 0; }
|