快速数论变换

快速数论变换(FNT)是环 $\mathbb{Z}/ m \mathbb{Z}$ 上的 Fourier 变换(FFT)。
至于 快速 Fourier 变换是怎样的有什么用处,这里就不多说了,可参考这里

NFT 的核心问题

无论是 FNT 还是 FFT 其本质其关键就是寻找一个 $w$ 使得 $w^{2^n} = 1$。在复数域中这个问题是显然的,而在一个环那就不那么简单了,这里我们考虑环 $R = \mathbb{Z}/ m \mathbb{Z}$。$R$ 为域当且仅当 $m$ 为素数。我们的问题是

  1. 我们选取 $R$ 中找比较大(满足我们的需求)的 $n$ 使得 $w^{2^n} = 1$ ;
  2. 找出对应的“原根“;
  3. 类似 FFT 的处理
  4. 应用时各种可能出错的情形,最常见的是溢出,还有只适用于数据范围不超过P的非负整数。

NFT 问题解决

为使得分析问题更为简单,我们考虑在$m = p$ 为素数的情形,此时,我们有 $2^n \mid p-1$ 即 $p=k 2^n+1$ 为(Fermat)素数,例如:

  1. $p=479 \times 2^{21} +1 = 1004535809,g = 3$
  2. $p= 13 \times 2^{20} + 1 = 13631489,g = 15$
  3. $p= 17 \times 2^{27} + 1 = 2281701377,g=3$
    更多常数选择可见这里

最终我们选择了 $FM = 1004535809$ 它的优势在于,它的两倍不超过 int 它的乘积不超过 long long 很有利于我们的运算,如果使用刚好不超过 long long 的数使用时很容易出现溢出并不方便。并且它恰好比较大。避免了做完 FFT 出现溢出。另外它可以取到的最大的 $N>2e6$ 也很不错。例如现在如果我们要做 $2^k,k \leq 21$ 的 NFT。那么我们取 $w = g^{\frac{p-1}{2^k}}$ 即可。

HDU 1402

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
/*------ Welcome to visit blog of dna049: http://dna049.com ------*/
const int N=132005;
char sa[N>>1],sb[N>>1];
LL a[N],b[N];
LL pow_mod(LL x,LL n,LL p){
LL r=1;
while(n){
if(n&1) r=r*x%p;
n>>=1; x=x*x%p;
}
return r;
}
void change(LL *x,int len,int loglen){
for(int i=1;i<len;++i){
int t=i,k=0;
for(int j=0;j<loglen;++j,t>>=1){
k=(k<<1)|(t&1);
}
if(k<i) swap(x[i],x[k]);
}
}
const LL FM = 479<<21|1;
void nft(LL *x,int len,int loglen,bool isInverse){
LL g = pow_mod(3,(FM-1)>>loglen,FM);
if(isInverse){
g=inv(g,FM);
LL invlen = pow_mod(len,FM-2,FM);
for(int i=0;i<len;++i){
x[i]=x[i]*invlen%FM;
}
}
change(x,len,loglen);
for(int step=2;step<=len;step<<=1){
int half = step>>1;
LL wn = pow_mod(g,len/step,FM);
for(int i=0;i<len;i+=step){
LL w = 1;
for(int j = i;j<i+half;++j){
LL t=(w*x[j+half])%FM;
x[j+half]=(x[j]-t+FM)%FM;
x[j]=(x[j]+t)%FM;
w = w*wn%FM;
}
}
}
}
int main(){
// freopen("/Users/dna049/Desktop/AC/in","r",stdin);
while(~scanf("%s%s",sa,sb)){
int alen=(int)strlen(sa);
int blen=(int)strlen(sb);
int len=1,loglen=0,tmp=alen+blen+3;
while(len<tmp){
len<<=1;++loglen;
}
clr(a,0);clr(b,0);
for(int i=0;i!=alen;++i) a[i]=sa[alen-i-1]-'0';
for(int i=0;i!=blen;++i) b[i]=sb[blen-i-1]-'0';
nft(a,len,loglen,0);
nft(b,len,loglen,0);
for(int i=0;i!=len;++i){
a[i] = a[i]*b[i]%FM;
}
nft(a,len,loglen,1);
int cnt=0;
while(cnt<len){
a[cnt+1]+=a[cnt]/10;
a[cnt]%=10;++cnt;
}
cnt=alen+blen;
while(cnt>1&&a[cnt-1]==0) --cnt;
for(int i=cnt-1;i>=0;--i){
putchar((int)a[i]+'0');
}
puts("");
}
return 0;
}