Montgomery Reduction 算法流程与实际实现

Montgomery Reduction 算法流程与实际实现

下面默认对于模数 \(m\) 取模,由于这篇文章的重点是实现(其实就是我自己存一下板子),因此没有证明。

使用注意:

Montgomery Reduction 相较于 Barret Reduction 来说,不需要使用 __int128。

但是有着更高的封装程度,因为涉及到普通数与 Montgomery Reduction 运算中间量的转化

另外,常见的 Montgomery Reduction 在编程竞赛中的应用 要求模数为奇数

但是在 Min25 博客上来看,Montgomery 似乎有着更高的效率

在工程领域, Montgomery 用于处理大二进制数的取模问题。

Montgomery Reduction算法思想简介

在计算取模运算的过程中,将每一个元素 \(T\) 都乘上一个特定的值 \(R(R>m,\gcd(R,m)=1)\)

用特殊的方法处理相乘时除掉一个 \(R\) 的过程,从而避免取模运算。

在使用的模数为常量时,编译器通常会自动加入 Barrett reduction 的优化,因此实际上这个算法对于动态模数的情形更为适用。

(你自己真不一定写得过STL,但是确实可以比STL块)


编程上的应用简介

对于 \(m\) 为奇数的情况,取 \(R=2^{32}\),用 自然溢出来代替取模/位运算位移代替除法 来加速运算。

我们还需要令 \(m' = -m^{-1} \mod R\),有结论:

对于某一个数 \(T,0 \leq T < mR\),若令 \(U = Tm’ \mod R\),则 \(\frac{T+Um}{R}\) 为整数,且 \(\frac{T+Um}{R}=TR^{-1} \mod m\)

那么我们在计算 \(\frac{T}{R}\) 时,实际上只需要计算 \(\frac{T+Um}{R}\),可以预处理 \(m'\),溢出计算 \(Tm'\),位运算左移计算 \(\frac{T+Um}{R}\)

实际使用时的实现,可以用一个类实现以下方法

在实现时需要尤其注意不要出现溢出

1. 预处理\(m'\)

\((R-\lfloor \frac{R}{m}\rfloor )\cdot (R\mod m)\)

1
2
3
4
5
6
7
8
9
10
using u32=unsigned;
using i32=int;
using u64=unsigned long long;
using i64=long long;
// inv=m'
u32 m;
u32 getinv(){
u32 inv=m;
for(int i=0;i<4;++i) inv*=2-inv*m;
}

2. reduce方法

1
2
3
4
5
u32 reduce(u64 x) {
u32 y = u32(x >> 32) - u32((u64(u32(x)*inv)*m) >> 32);
// 先取u32(x)得到x mod R ,然后再转成u64进行乘法
return i32(y) < 0 ? y + m : y;
}

3. 普通数转Montgomery Reduction

我们要计算\(x\rightarrow xR=x\cdot 2^{32}\),但是如果直接用取模就失去了意义。。。

方法是快速计算\(x\cdot R^2\),然后reduce一次

1
2
3
4
u32 R2=-u64(m)%m;
u32 intToMont(i32 x){
return reduce(u64(x)*R2);
}

\[ \ \]

4. Montomery运算

1
2
3
4
5
6
7
8
9
10
11
u32 Add(u32 x,u32 y) {
x+=y-m;
return i32(x<0)?x+m:x;
}
u32 Dec(u32 x,u32 y){
x-=y;
return i32(x<0)?x+m:x;
}
u32 Mul(u32 x,u32 y){
return reduce(u64(x)*y);
}

\[ \ \]

5. Montomery Reduction转普通数

1
2
3
i32 get(u32 x){
return reduce(x);
}

封装之后,得到板子一号,这个是动态模数的。。。

实现上可能的误区:

为什么不用 -inv? 避免加法,原因是加法取模要和 m 比较

同样的,下面的 i32(y)<0 语句可以被替换为 y>=m (负数溢出),看似减少一次类型转换,但是实际上0作为常量比较快得多

加法运算时也是类似的原因,x>=m 比较实在太慢,因此强制减去一个 m,然后和 0 比

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
using u32=uint32_t;
using i32=int32_t;
using u64=uint64_t;
using i64=int64_t;

static u32 m,inv,r2,P;
u32 getinv(){
u32 inv=m;
for(int i=0;i<4;++i) inv*=2-inv*m;
return inv;
}
struct Mont{
private :
u32 x;
public :
static u32 reduce(u64 x){
u32 y=u32(x>>32)-u32((u64(u32(x)*inv)*m)>>32);
return i32(y)<0?y+m:y;
}
Mont(){ ; }
Mont(i32 x):x(reduce(u64(x)*r2)) { }
Mont& operator += (const Mont &rhs) { return x+=rhs.x-m,is32(x)<0&&(x+=m),*this; }
Mont& operator -= (const Mont &rhs) { return x-=rhs.x,i32(x)<0&&(x+=m),*this; }
Mont& operator *= (const Mont &rhs) { return x=reduce(u64(x)*rhs.x),*this; }
friend Mont operator + (Mont x,const Mont &y) { return x+=y; }
friend Mont operator - (Mont x,const Mont &y) { return x-=y; }
friend Mont operator * (Mont x,const Mont &y) { return x*=y; }
i32 get(){ return reduce(x); }
};
void Init(int m) {
::m=m;
inv=-getinv();
r2=-u64(m)%m;
}

动态模数的方法,计算 \(5\cdot 10^7!\mod 998244353\) 在 duck.ac 上评测结果,时间单位是微秒\(\mu s\)

1
2
Naive Mod     : 213689172  Time: 518352
My Montgomery : 213689172 Time: 192195



这个是我自己写的静态模数的,因为模数是静态的,所以不需要一定和0比较大小

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
template <uint32_t m> struct Mont{
private :
using u32=uint32_t;
using i32=int32_t;
using u64=uint64_t;
using i64=int64_t;
static constexpr u32 getinv(){
u32 inv=m;
for(int i=0;i<4;++i) inv*=2-inv*m;
return inv;
}
static constexpr u32 inv=-getinv(),r2=-u64(m)%m;
u32 x;
public :
static constexpr u32 reduce(u64 x){
u32 y=(x+u64(u32(x)*inv)*m)>>32;
return y>=m?y-m:y;
}
Mont(){ ; }
constexpr Mont(i32 x):x(reduce(u64(x)*r2)) { }
constexpr Mont& operator += (const Mont &rhs) { return x+=rhs.x-m,x>=m&&(x+=m),*this; }
constexpr Mont& operator -= (const Mont &rhs) { return x-=rhs.x,x>=m&&(x+=m),*this; }
constexpr Mont& operator *= (const Mont &rhs) { return x=reduce(u64(x)*rhs.x),*this; }
constexpr friend Mont operator + (Mont x,const Mont &y) { return x+=y; }
constexpr friend Mont operator - (Mont x,const Mont &y) { return x-=y; }
constexpr friend Mont operator * (Mont x,const Mont &y) { return x*=y; }
constexpr i32 get(){ return reduce(x); }
} ;

这个是摘自 LOJ多项式乘法 hly1204的提交记录

个人解读:实际上每次存储的是 \(x \mod 2m\) 的值,避免了reduce时的加减取模。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
// from https://min-25.hatenablog.com/entry/2017/08/20/171214
template <std::uint32_t P> struct MontgomeryModInt32 {
public:
using i32 = std::int32_t;
using u32 = std::uint32_t;
using i64 = std::int64_t;
using u64 = std::uint64_t;

private:
u32 v;

static constexpr u32 get_r() {
u32 iv = P;
for (u32 i = 0; i != 4; ++i) iv *= 2 - P * iv;
return iv;
}

static constexpr u32 r = -get_r(), r2 = -u64(P) % P;

static_assert((P & 1) == 1);
static_assert(r * P == -1);
static_assert(P < (1 << 30));

public:
static constexpr u32 pow_mod(u32 x, u64 y) {
if ((y %= P - 1) < 0) y += P - 1;
u32 res = 1;
for (; y != 0; y >>= 1, x = u64(x) * x % P)
if (y & 1) res = u64(res) * x % P;
return res;
}

static constexpr u32 get_pr() {
u32 tmp[32] = {}, cnt = 0;
const u64 phi = P - 1;
u64 m = phi;
for (u64 i = 2; i * i <= m; ++i) {
if (m % i == 0) {
tmp[cnt++] = i;
while (m % i == 0) m /= i;
}
}
if (m > 1) tmp[cnt++] = m;
for (u64 res = 2; res <= phi; ++res) {
bool flag = true;
for (u32 i = 0; i != cnt && flag; ++i) flag &= pow_mod(res, phi / tmp[i]) != 1;
if (flag) return res;
}
return 0;
}

MontgomeryModInt32() = default;
~MontgomeryModInt32() = default;
constexpr MontgomeryModInt32(u32 v) : v(reduce(u64(v) * r2)) {}
constexpr MontgomeryModInt32(const MontgomeryModInt32 &rhs) : v(rhs.v) {}
static constexpr u32 reduce(u64 x) { return x + (u64(u32(x) * r) * P) >> 32; }
constexpr u32 get() const {
u32 res = reduce(v);
return res - (P & -(res >= P));
}
explicit constexpr operator u32() const { return get(); }
explicit constexpr operator i32() const { return i32(get()); }
constexpr MontgomeryModInt32 &operator=(const MontgomeryModInt32 &rhs) {
return v = rhs.v, *this;
}
constexpr MontgomeryModInt32 operator-() const {
MontgomeryModInt32 res;
return res.v = (P << 1 & -(v != 0)) - v, res;
}
constexpr MontgomeryModInt32 inv() const { return pow(-1); }
constexpr MontgomeryModInt32 &operator+=(const MontgomeryModInt32 &rhs) {
return v += rhs.v - (P << 1), v += P << 1 & -(i32(v) < 0), *this;
}
constexpr MontgomeryModInt32 &operator-=(const MontgomeryModInt32 &rhs) {
return v -= rhs.v, v += P << 1 & -(i32(v) < 0), *this;
}
constexpr MontgomeryModInt32 &operator*=(const MontgomeryModInt32 &rhs) {
return v = reduce(u64(v) * rhs.v), *this;
}
constexpr MontgomeryModInt32 &operator/=(const MontgomeryModInt32 &rhs) {
return this->operator*=(rhs.inv());
}
friend MontgomeryModInt32 operator+(const MontgomeryModInt32 &lhs,
const MontgomeryModInt32 &rhs) {
return MontgomeryModInt32(lhs) += rhs;
}
friend MontgomeryModInt32 operator-(const MontgomeryModInt32 &lhs,
const MontgomeryModInt32 &rhs) {
return MontgomeryModInt32(lhs) -= rhs;
}
friend MontgomeryModInt32 operator*(const MontgomeryModInt32 &lhs,
const MontgomeryModInt32 &rhs) {
return MontgomeryModInt32(lhs) *= rhs;
}
friend MontgomeryModInt32 operator/(const MontgomeryModInt32 &lhs,
const MontgomeryModInt32 &rhs) {
return MontgomeryModInt32(lhs) /= rhs;
}
friend std::istream &operator>>(std::istream &is, MontgomeryModInt32 &rhs) {
return is >> rhs.v, rhs.v = reduce(u64(rhs.v) * r2), is;
}
friend std::ostream &operator<<(std::ostream &os, const MontgomeryModInt32 &rhs) {
return os << rhs.get();
}
constexpr MontgomeryModInt32 pow(i64 y) const {
if ((y %= P - 1) < 0) y += P - 1; // phi(P) = P - 1, assume P is a prime number
MontgomeryModInt32 res(1), x(*this);
for (; y != 0; y >>= 1, x *= x)
if (y & 1) res *= x;
return res;
}
};

这个是计算 \(5\cdot 10^7!\mod 998244353\) 在 duck.ac 上的测试结果。

1
2
3
Naive Mod      : 213689172  Time: 180649
My Montgomery : 213689172 Time: 178217
His Montgomery : 213689172 Time: 152847

这个是计算 \(7\cdot 10^7!\mod 998244353\) 在 duck.ac 上的测试结果。

1
2
3
Naive Mod      : 939830261  Time: 252908
My Montgomery : 939830261 Time: 249476
His Montgomery : 939830261 Time: 213986

还可以看Min25博客里下面的ModInt64板本

传送门

下面自己实现的\(\mod 2m\) 版本,差不多也是最终版本了,跑起来和hly1204差不多

静态版本

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
template <uint32_t m> struct Mont2{
private :
using u32=uint32_t;
using i32=int32_t;
using u64=uint64_t;
using i64=int64_t;
static constexpr u32 m2=m<<1;
static constexpr u32 getinv(){
u32 inv=m;
for(int i=0;i<4;++i) inv*=2-inv*m;
return inv;
}
static constexpr u32 inv=-getinv(),r2=-u64(m)%m;
u32 x;
public :
static constexpr u32 reduce(u64 x){
return (x+u64(u32(x)*inv)*m)>>32;
}
Mont2(){ ; }
constexpr Mont2(i32 x):x(reduce(u64(x)*r2)) { }
constexpr Mont2& operator += (const Mont2 &rhs) { return x+=rhs.x-m2,x>=m2&&(x+=m2),*this; }
constexpr Mont2& operator -= (const Mont2 &rhs) { return x-=rhs.x,x>=m2&&(x+=m2),*this; }
constexpr Mont2& operator *= (const Mont2 &rhs) { return x=reduce(u64(x)*rhs.x),*this; }
constexpr friend Mont2 operator + (Mont2 x,const Mont2 &y) { return x+=y; }
constexpr friend Mont2 operator - (Mont2 x,const Mont2 &y) { return x-=y; }
constexpr friend Mont2 operator * (Mont2 x,const Mont2 &y) { return x*=y; }
constexpr i32 get(){
u32 res=reduce(x);
return res>=m?res-m:res;
}
} ;

板子各有优劣.jpg

另外这是 Int_To_Montgomery 加法的速度,\(7\cdot 10^7\) 次加法与类型转换。

1
2
3
4
Naive :        : 305907824 80074
My Montgomery : 305907824 109479
My Montgomery2 : 305907824 99896
His Montgomery : 305907824 117449

动态版本

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
using u32=uint32_t;
using i32=int32_t;
using u64=uint64_t;
using i64=int64_t;

static u32 m,m2,inv,r2,P;
u32 getinv(){
u32 inv=m;
for(int i=0;i<4;++i) inv*=2-inv*m;
return inv;
}
struct Mont{
private :
u32 x;
public :
static u32 reduce(u64 x){
u32 y=(x+u64(u32(x)*inv)*m)>>32;
return i32(y)<0?y+m:y;
}
Mont(){ ; }
Mont(i32 x):x(reduce(u64(x)*r2)) { }
Mont& operator += (const Mont &rhs) { return x+=rhs.x-m2,i32(x)<0&&(x+=m2),*this; }
Mont& operator -= (const Mont &rhs) { return x-=rhs.x,i32(x)<0&&(x+=m2),*this; }
Mont& operator *= (const Mont &rhs) { return x=reduce(u64(x)*rhs.x),*this; }
friend Mont operator + (Mont x,const Mont &y) { return x+=y; }
friend Mont operator - (Mont x,const Mont &y) { return x-=y; }
friend Mont operator * (Mont x,const Mont &y) { return x*=y; }
i32 get(){
u32 res=reduce(x);
return res>=m?res-m:res;
}
};
void Init(int m) {
::m=m,m2=m*2;
inv=-getinv();
r2=-u64(m)%m;
}

这个动态模板计算 \(5\cdot 10^7!\mod 998244353\)

1
2
Naive Mod      : 213689172 494061 (稍微修改了一下暴力的细节。。)
My Montgomery2 : 213689172 152849

不得不说 duck.ac 真的很 nb。