在實作一些演算法,如 Miller-Rabin 時,會出現 long long 整數相乘再取餘數的步驟。

1
2
long long a, b, c;
auto res = a * b % c;

然而若 a*b 的結果大於 long long 的表達範圍的話會溢位,便無法取得正確的結果。因此會需要自行實做乘法來避免溢位。具體方法可以參考維基百科,簡單來說就是模擬直式乘法:如果每次最多乘以 2,那麼一個 63 bit 的整數可以在 64 bit 的空間內計算出來不會發生溢位。因此實作上用 uint64_t 來儲存即可。

不過絕大多數的 OJ 與競賽編譯器都使用 GCC,而 GCC 帶有非標準的 128 bit 整數 __int128 🔗,如果環境允許使用的話,那麼就只需要把 long long轉型到__int128運算即可。

很久以前有人說__int128運算上很慢,因此似乎讓後者的實做不是特別的流行。既然最近剛好寫到了一題,就來做個實驗。

就直白說結論了:__int128 比自行用迴圈實作乘法快了 30~50 倍 (O3 優化)。

因此如果環境許可的話,就用 __int128 ,程式碼既簡單又快速。具體的 __int128 組語要再來研究看看了。

方法 107 次執行時間
int128 48 ms
Morris 的實作 🔗 2066 ms
Wiki 的實作 1245 ms

其實原本想拿日月卦長的實作,不過該實作有上面提到實作細節的 bug ,而且跟 wiki 差不多,就沒放上了。

實驗程式碼

  • 實驗環境
    • 編譯器:gcc version 9.3.0 (Ubuntu 9.3.0-17ubuntu1~20.04)
    • 編譯參數:g++ -O3
    • CPU: AMD-3700X
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
#include <bits/stdc++.h>
using namespace std;

using ll = long long;

inline ll mul_int128(ll a, ll b, ll m)
{
ll res = 0;
__int128 _a = a;
__int128 _b = b;
_a = (_a * _b) % m;
return _a;
}

long long mul_morris(unsigned long long a, unsigned long long b, unsigned long long mod) {
long long ret = 0;
for (a %= mod, b %= mod; b != 0; b >>= 1, a <<= 1, a = a >= mod ? a - mod : a) {
if (b&1) {
ret += a;
if (ret >= mod) ret -= mod;
}
}
return ret;
}

uint64_t mul_wiki(uint64_t a, uint64_t b, uint64_t m)
{
uint64_t d = 0, mp2 = m >> 1;
int i;
if (a >= m) a %= m;
if (b >= m) b %= m;
for (i = 0; i < 64; ++i)
{
d = (d > mp2) ? (d << 1) - m : d << 1;
if (a & 0x8000000000000000ULL)
d += b;
if (d > m) d -= m;
a <<= 1;
}
return d%m;
}

#define Test(Expr) \
[]() \
{ \
std::mt19937_64 mt(7122); \
ll hashv = 0; \
auto TimeStart = std::chrono::high_resolution_clock::now(); \
for (int i = 0; i < 10; ++i) \
for (int j = 0; j < 1000000; ++j) \
{ \
ll a = mt() % LLONG_MAX; \
ll b = mt() % LLONG_MAX; \
ll c = 1 + mt() % (LLONG_MAX - 1); \
hashv ^= (Expr); \
} \
auto TimeEnd = std::chrono::high_resolution_clock::now(); \
auto ms = std::chrono::duration_cast<std::chrono::microseconds>(TimeEnd - TimeStart); \
return std::make_tuple(ms.count(), hashv); \
}();

int main()
{
int Trys = 10;
const ll EmptyHash = 6862552923007731049LL;
const ll AcceptHash = 8476752227605395455LL;
ll TotalRunTime, RunTime;
ll HashVal = 0;

ll EmptyRunTime = 0;
ll Int128RunTime = 0;
ll MorrisMulRunTime = 0;
ll WikiMulRunTime = 0;

cout << "Test Empty loop..." << endl;
TotalRunTime = 0;
for (int i = 1; i <= Trys; ++i)
{
tie(RunTime, HashVal) = Test(a^b^c);
TotalRunTime += RunTime;
cout << setw(2) << i << ". " << RunTime << " microsecond" << endl;
assert(HashVal == EmptyHash);
}
EmptyRunTime = TotalRunTime / Trys;

cout << "Test mul : __int128..." << endl;
TotalRunTime = 0;
for (int i = 1; i <= Trys; ++i)
{
tie(RunTime, HashVal) = Test(mul_int128(a,b,c));
TotalRunTime += RunTime;
cout << setw(2) << i << ". " << RunTime << " microsecond" << endl;
assert(HashVal == AcceptHash);
}
Int128RunTime = TotalRunTime / Trys - EmptyRunTime;

cout << "Test mul : morris" << endl;
TotalRunTime = 0;
for (int i = 1; i <= Trys; ++i)
{
tie(RunTime, HashVal) = Test(mul_morris(a,b,c));
TotalRunTime += RunTime;
cout << setw(2) << i << ". " << RunTime << " microsecond" << endl;
assert(HashVal == AcceptHash);
}
MorrisMulRunTime = TotalRunTime / Trys - EmptyRunTime;

cout << "Test mul : wiki" << endl;
TotalRunTime = 0;
for (int i = 1; i <= Trys; ++i)
{
tie(RunTime, HashVal) = Test(mul_wiki(a,b,c));
TotalRunTime += RunTime;
cout << setw(2) << i << ". " << RunTime << " microsecond" << endl;
assert(HashVal == AcceptHash);
}
WikiMulRunTime = TotalRunTime / Trys - EmptyRunTime;

cout << "======================" << endl;
cout << "Empty : " << EmptyRunTime << "\tmicrosecond" << endl;
cout << "======================" << endl;
cout << "Int128: " << Int128RunTime << "\tmicrosecond" << endl;
cout << "Morris: " << MorrisMulRunTime << "\tmicrosecond" << endl;
cout << "Wiki : " << WikiMulRunTime << "\tmicrosecond" << endl;
cout << "======================" << endl;

}

程式輸出

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
Test Empty loop...
1. 204376 microsecond
2. 195937 microsecond
3. 198422 microsecond
4. 193477 microsecond
5. 196865 microsecond
6. 199258 microsecond
7. 196088 microsecond
8. 203152 microsecond
9. 194675 microsecond
10. 193845 microsecond
Test mul : __int128...
1. 242874 microsecond
2. 246191 microsecond
3. 246739 microsecond
4. 244880 microsecond
5. 246286 microsecond
6. 245972 microsecond
7. 246628 microsecond
8. 245438 microsecond
9. 247223 microsecond
10. 245661 microsecond
Test mul : morris
1. 2265340 microsecond
2. 2271909 microsecond
3. 2326091 microsecond
4. 2247727 microsecond
5. 2277070 microsecond
6. 2240657 microsecond
7. 2235139 microsecond
8. 2277157 microsecond
9. 2250817 microsecond
10. 2247773 microsecond
Test mul : wiki
1. 1435613 microsecond
2. 1434507 microsecond
3. 1451556 microsecond
4. 1440528 microsecond
5. 1442246 microsecond
6. 1439456 microsecond
7. 1442855 microsecond
8. 1469818 microsecond
9. 1428127 microsecond
10. 1445635 microsecond
======================
Empty : 197609 microsecond
======================
Int128: 48180 microsecond
Morris: 2066359 microsecond
Wiki : 1245425 microsecond
======================

反組譯參考

很明顯的 __int128 的輸出短很多,而且都是簡易運算指令。自行實作的乘法有 jmp 等跳轉,明顯複雜。