生活随笔
收集整理的這篇文章主要介紹了
任意模数NTT(MTT)
小編覺得挺不錯的,現在分享給大家,幫大家做個參考.
前言
眾所周知,NTT有幾個經典的模數:469762049,998244353,1004535809469762049,998244353,1004535809 4 6 9 7 6 2 0 4 9 , 9 9 8 2 4 4 3 5 3 , 1 0 0 4 5 3 5 8 0 9 為什么這些模數被稱為NTT模數呢?因為他們都是這樣一個形式: P=2a?X+1P=2^a*X+1 P = 2 a ? X + 1 為什么要有這樣一個條件呢,因為只有這樣,才能找到所需的原根 所以對于一般的一個模數P=2a?X+1P=2^a*X+1 P = 2 a ? X + 1 ,能適用的最大的多項式長度(包括結果)是2a2^a 2 a 有時候, 給出的多項式長度超過限制,我們就不能用裸的NTT了 一般有兩種情況:
模數是NTT模數,但是多項式長度略超出限制(比如模數是1004535809,輸入多項式長度和>2097152) 模數不是NTT模數,比如模數是1000000007 這個時候任意模數NTT就非常有用了
正文
我們來分析任意模數NTT做法的思路
思路一(P不是很大的時候)
根據分析,我們發現,多項式長度為N、模數為P的時候,多項式乘法的結果每一項的值0≤x≤P2N0\le x\le P^2N 0 ≤ x ≤ P 2 N 由于NTT的復雜度是Θ(nlogn)\Theta(nlogn) Θ ( n l o g n ) 的,所以nn n 的范圍可以出到10510^5 1 0 5 以上,而對于10910^9 1 0 9 級別的質數,那么結果大約是102310^{23} 1 0 2 3 級別的。如果不考慮值域,有個很好的思路是:先進行FFT,算完后取模 非常不幸的是,由于結果的值域過大,FFT的精度往往都不夠(這也是為什么要使用NTT的原因,根據實測,使用long double 的FFT,當值域≤1013\le 10^{13} ≤ 1 0 1 3 的時候,FFT是精度較好的,值域更大的時候出錯概率就會比較高了,Tip:FFT的精度并不只與值域相關,多項式長度同樣會影響精度(似乎是在Pi/n這個地方損失了精度),博主對各個長度都進行了測試,取min值 ) 寫一些上界(粗略) N=1000000N=1000000 N = 1 0 0 0 0 0 0 X=6000X=6000 X = 6 0 0 0 N=100000N=100000 N = 1 0 0 0 0 0 X=40000X=40000 X = 4 0 0 0 0 N=10000N=10000 N = 1 0 0 0 0 X=300000X=300000 X = 3 0 0 0 0 0 N=1000N=1000 N = 1 0 0 0 X=1000000X=1000000 X = 1 0 0 0 0 0 0 N=100N=100 N = 1 0 0 X=6000000X=6000000 X = 6 0 0 0 0 0 0 N=10N=10 N = 1 0 X=20000000X=20000000 X = 2 0 0 0 0 0 0 0 (精度值在1000000下為6000,對拍程序為NTT) 所以如果你發現質數不是很大,即P2N≤1013P^2N\le 10^{13} P 2 N ≤ 1 0 1 3 的時候,你可以放心的FFT(本測試的多項式長度上限為10610^6 1 0 6 ) 注意:實際對于一般的FFT,保守限制為101010^{10} 1 0 1 0 ,因為long double 可能會出現莫名的錯誤(博主太菜了,寫的代碼就出現UKE,有關using namespace std和std::的,導致精度大大下降)
思路二(基于FFT的優化)(2.6倍的普通FFT,多項式長度受限較大)
貼出一道模板題,本思路以及可以解決模板題了 洛谷模板題:任意模數NTT 我們發現FFT并不是一無是處,所以我們考慮壓縮值域 設p=?P?p=\left\lceil\sqrt P\right\rceil p = ? P ? ? 那么任何一個數都能表示成X=axp+bx(X<P,ax<p,bx<p)X=a_xp+b_x(X<P,a_x<p,b_x<p) X = a x ? p + b x ? ( X < P , a x ? < p , b x ? < p ) 的形式 那么我們考慮結果的一個值,對其進行分析 V=∑(axp+bx)?(ayp+by)=∑axayp2+(axby+bxay)p+bxby\begin{aligned} V&=\sum(a_xp+b_x)*(a_yp+b_y)\\ &=\sum a_xa_yp^2+(a_xb_y+b_xa_y)p+b_xb_y\\ \end{aligned} V ? = ∑ ( a x ? p + b x ? ) ? ( a y ? p + b y ? ) = ∑ a x ? a y ? p 2 + ( a x ? b y ? + b x ? a y ? ) p + b x ? b y ? ? 我們對每一組系數都進行計算,容易發現NN N 是10510^5 1 0 5 級別的,值域p<40000p<40000 p < 4 0 0 0 0 ,所以可以求出各項,然后再加起來 貼出AC代碼
#include <cstdio>
#include <cctype>
#include <cmath>
#include <algorithm>
using namespace std
;
namespace fast_IO
{ const int IN_LEN
= 10000000 , OUT_LEN
= 10000000 ; char ibuf
[ IN_LEN
] , obuf
[ OUT_LEN
] , * ih
= ibuf
+ IN_LEN
, * oh
= obuf
, * lastin
= ibuf
+ IN_LEN
, * lastout
= obuf
+ OUT_LEN
- 1 ; inline char getchar_ ( ) { return ( ih
== lastin
) && ( lastin
= ( ih
= ibuf
) + fread ( ibuf
, 1 , IN_LEN
, stdin ) , ih
== lastin
) ? EOF : * ih
++ ; } inline void putchar_ ( const char x
) { if ( oh
== lastout
) fwrite ( obuf
, 1 , oh
- obuf
, stdout ) , oh
= obuf
; * oh
++ = x
; } inline void flush ( ) { fwrite ( obuf
, 1 , oh
- obuf
, stdout ) ; }
}
using namespace fast_IO
;
#define getchar() getchar_()
#define putchar(x) putchar_((x))
#define rg register
typedef long long LL
;
typedef long double LD
;
#define double LD
template
< typename T
> inline T
max ( const T a
, const T b
) { return a
> b
? a
: b
; }
template
< typename T
> inline T
min ( const T a
, const T b
) { return a
< b
? a
: b
; }
template
< typename T
> inline void mind ( T
& a
, const T b
) { a
= a
< b
? a
: b
; }
template
< typename T
> inline void maxd ( T
& a
, const T b
) { a
= a
> b
? a
: b
; }
template
< typename T
> inline T
abs ( const T a
) { return a
> 0 ? a
: - a
; }
template
< typename T
> inline T
gcd ( const T a
, const T b
) { if ( ! b
) return a
; return gcd ( b
, a
% b
) ; }
template
< typename T
> inline T
lcm ( const T a
, const T b
) { return a
/ gcd ( a
, b
) * b
; }
template
< typename T
> inline T
square ( const T x
) { return x
* x
; } ;
template
< typename T
> inline void read ( T
& x
)
{ char cu
= getchar ( ) ; x
= 0 ; bool fla
= 0 ; while ( ! isdigit ( cu
) ) { if ( cu
== '-' ) fla
= 1 ; cu
= getchar ( ) ; } while ( isdigit ( cu
) ) x
= x
* 10 + cu
- '0' , cu
= getchar ( ) ; if ( fla
) x
= - x
;
}
template
< typename T
> inline void printe ( const T x
)
{ if ( x
>= 10 ) printe ( x
/ 10 ) ; putchar ( x
% 10 + '0' ) ;
}
template
< typename T
> inline void print ( const T x
)
{ if ( x
< 0 ) putchar ( '-' ) , printe ( - x
) ; else printe ( x
) ;
}
const int maxn
= 262145 ; const double PI
= acos ( ( LD
) - 1.0 ) ;
int n
, m
;
struct complex
{ double x
, y
; inline complex operator
+ ( const complex b
) const { return ( complex
) { x
+ b
. x
, y
+ b
. y
} ; } inline complex operator
- ( const complex b
) const { return ( complex
) { x
- b
. x
, y
- b
. y
} ; } inline complex operator
* ( const complex b
) const { return ( complex
) { x
* b
. x
- y
* b
. y
, x
* b
. y
+ y
* b
. x
} ; }
} ax
[ maxn
] , ay
[ maxn
] , bx
[ maxn
] , by
[ maxn
] ;
int lenth
= 1 , Reverse
[ maxn
] ;
inline void init ( const int x
)
{ rg
int tim
= 0 ; while ( lenth
<= x
) lenth
<<= 1 , tim
++ ; for ( rg
int i
= 0 ; i
< lenth
; i
++ ) Reverse
[ i
] = ( Reverse
[ i
>> 1 ] >> 1 ) | ( ( i
& 1 ) << ( tim
- 1 ) ) ;
}
inline void FFT ( complex
* A
, const int fla
)
{ for ( rg
int i
= 0 ; i
< lenth
; i
++ ) if ( i
< Reverse
[ i
] ) swap ( A
[ i
] , A
[ Reverse
[ i
] ] ) ; for ( rg
int i
= 1 ; i
< lenth
; i
<<= 1 ) { const complex w
= ( complex
) { cos ( PI
/ i
) , fla
* sin ( PI
/ i
) } ; for ( rg
int j
= 0 ; j
< lenth
; j
+ = ( i
<< 1 ) ) { complex K
= ( complex
) { 1 , 0 } ; for ( rg
int k
= 0 ; k
< i
; k
++ , K
= K
* w
) { const complex x
= A
[ j
+ k
] , y
= A
[ j
+ k
+ i
] * K
; A
[ j
+ k
] = x
+ y
; A
[ j
+ k
+ i
] = x
- y
; } } }
}
int P
, p
;
int main ( )
{ read ( n
) , read ( m
) , read ( P
) ; p
= 31624 ; init ( n
+ m
) ; for ( rg
int i
= 0 ; i
<= n
; i
++ ) { int x
; read ( x
) ; ax
[ i
] . x
= x
/ p
, bx
[ i
] . x
= x
% p
; } for ( rg
int i
= 0 ; i
<= m
; i
++ ) { int x
; read ( x
) ; ay
[ i
] . x
= x
/ p
, by
[ i
] . x
= x
% p
; } FFT ( ax
, 1 ) , FFT ( bx
, 1 ) , FFT ( ay
, 1 ) , FFT ( by
, 1 ) ; for ( rg
int i
= 0 ; i
< lenth
; i
++ ) { const complex A
= ax
[ i
] , B
= bx
[ i
] , C
= ay
[ i
] , D
= by
[ i
] ; ax
[ i
] = A
* C
, ay
[ i
] = B
* D
; bx
[ i
] = A
* D
, by
[ i
] = B
* C
; } FFT ( ax
, - 1 ) , FFT ( bx
, - 1 ) , FFT ( ay
, - 1 ) , FFT ( by
, - 1 ) ; for ( rg
int i
= 0 ; i
<= n
+ m
; i
++ ) { const LL A
= ax
[ i
] . x
/ lenth
+ 0.5 , B
= ay
[ i
] . x
/ lenth
+ 0.5 , C
= bx
[ i
] . x
/ lenth
+ 0.5 , D
= by
[ i
] . x
/ lenth
+ 0.5 ; print ( ( A
% P
* p
% P
* p
% P
+ B
% P
+ ( C
% P
+ D
% P
) * p
% P
) % P
) , putchar ( ' ' ) ; } return flush ( ) , 0 ;
}
這里又出現了UKE!!!博主太菜啦 如果p的賦值寫成?P?\left\lceil\sqrt P\right\rceil ? P ? ? ,就會會在洛谷上WA兩個點 如果發現我哪里寫掛了,請速聯系我! 效率分析,一次一般的多項式乘法共調用3次FFT函數,這里調用了8次,所以這種任意模數NTT算法常數大概是2.6左右 update by 2019.1.7:可以通過一些技巧減小精度損失以支持更多位數或在當前位數下只使用double(tip by yx2003) 詳細方法:將FFT函數中的K直接預處理即可,減少乘法中的精度損失,對多項式長度較長(100000及以上) 的情況有較大優化效果 為什么呢?大概這個精度是受限兩個方面:一個是值域上限的限制(在多項式長度較小,值域較大時體現),一個是多項式長度的限制(在多項式長度較大時體現)。容易發現多項式長的時候精度受限在單位根上,這個優化就是針對單位根精度的優化 提升效果: N=100000N=100000 N = 1 0 0 0 0 0 X=40000?100000X=40000\Rightarrow100000 X = 4 0 0 0 0 ? 1 0 0 0 0 0
N=1000000N=1000000 N = 1 0 0 0 0 0 0 X=6000?30000X=6000\Rightarrow30000 X = 6 0 0 0 ? 3 0 0 0 0 對于模板題的速度也能有較大提升,大約用時是原來的12\frac12 2 1 ? 代碼
#include <cstdio>
#include <cctype>
#include <cmath>
#include <algorithm>
using namespace std
;
namespace fast_IO
{ const int IN_LEN
= 10000000 , OUT_LEN
= 10000000 ; char ibuf
[ IN_LEN
] , obuf
[ OUT_LEN
] , * ih
= ibuf
+ IN_LEN
, * oh
= obuf
, * lastin
= ibuf
+ IN_LEN
, * lastout
= obuf
+ OUT_LEN
- 1 ; inline char getchar_ ( ) { return ( ih
== lastin
) && ( lastin
= ( ih
= ibuf
) + fread ( ibuf
, 1 , IN_LEN
, stdin ) , ih
== lastin
) ? EOF : * ih
++ ; } inline void putchar_ ( const char x
) { if ( oh
== lastout
) fwrite ( obuf
, 1 , oh
- obuf
, stdout ) , oh
= obuf
; * oh
++ = x
; } inline void flush ( ) { fwrite ( obuf
, 1 , oh
- obuf
, stdout ) ; }
}
using namespace fast_IO
;
#define getchar() getchar_()
#define putchar(x) putchar_((x))
#define rg register
typedef long long LL
;
template
< typename T
> inline T
max ( const T a
, const T b
) { return a
> b
? a
: b
; }
template
< typename T
> inline T
min ( const T a
, const T b
) { return a
< b
? a
: b
; }
template
< typename T
> inline void mind ( T
& a
, const T b
) { a
= a
< b
? a
: b
; }
template
< typename T
> inline void maxd ( T
& a
, const T b
) { a
= a
> b
? a
: b
; }
template
< typename T
> inline T
abs ( const T a
) { return a
> 0 ? a
: - a
; }
template
< typename T
> inline T
gcd ( const T a
, const T b
) { if ( ! b
) return a
; return gcd ( b
, a
% b
) ; }
template
< typename T
> inline T
lcm ( const T a
, const T b
) { return a
/ gcd ( a
, b
) * b
; }
template
< typename T
> inline T
square ( const T x
) { return x
* x
; } ;
template
< typename T
> inline void read ( T
& x
)
{ char cu
= getchar ( ) ; x
= 0 ; bool fla
= 0 ; while ( ! isdigit ( cu
) ) { if ( cu
== '-' ) fla
= 1 ; cu
= getchar ( ) ; } while ( isdigit ( cu
) ) x
= x
* 10 + cu
- '0' , cu
= getchar ( ) ; if ( fla
) x
= - x
;
}
template
< typename T
> inline void printe ( const T x
)
{ if ( x
>= 10 ) printe ( x
/ 10 ) ; putchar ( x
% 10 + '0' ) ;
}
template
< typename T
> inline void print ( const T x
)
{ if ( x
< 0 ) putchar ( '-' ) , printe ( - x
) ; else printe ( x
) ;
}
const int maxn
= 262145 ; const double PI
= acos ( ( double ) - 1.0 ) ;
int n
, m
;
struct complex
{ double x
, y
; inline complex operator
+ ( const complex b
) const { return ( complex
) { x
+ b
. x
, y
+ b
. y
} ; } inline complex operator
- ( const complex b
) const { return ( complex
) { x
- b
. x
, y
- b
. y
} ; } inline complex operator
* ( const complex b
) const { return ( complex
) { x
* b
. x
- y
* b
. y
, x
* b
. y
+ y
* b
. x
} ; }
} ax
[ maxn
] , ay
[ maxn
] , bx
[ maxn
] , by
[ maxn
] ;
int lenth
= 1 , Reverse
[ maxn
] ;
complex w
[ maxn
] ;
complex fw
[ maxn
] ;
inline void init ( const int x
)
{ rg
int tim
= 0 ; while ( lenth
<= x
) lenth
<<= 1 , tim
++ ; for ( rg
int i
= 0 ; i
< lenth
; i
++ ) Reverse
[ i
] = ( Reverse
[ i
>> 1 ] >> 1 ) | ( ( i
& 1 ) << ( tim
- 1 ) ) , w
[ i
] = ( complex
) { cos ( i
* PI
/ lenth
) , sin ( i
* PI
/ lenth
) } , fw
[ i
] = ( complex
) { cos ( i
* PI
/ lenth
) , - sin ( i
* PI
/ lenth
) } ;
}
complex W
[ maxn
] ;
inline void FFT ( complex
* A
, const int fla
)
{ for ( rg
int i
= 0 ; i
< lenth
; i
++ ) if ( i
< Reverse
[ i
] ) swap ( A
[ i
] , A
[ Reverse
[ i
] ] ) ; for ( rg
int i
= 1 ; i
< lenth
; i
<<= 1 ) { if ( fla
== 1 ) { for ( rg
int k
= 0 ; k
< i
; k
++ ) W
[ k
] = w
[ lenth
/ i
* k
] ; } else { for ( rg
int k
= 0 ; k
< i
; k
++ ) W
[ k
] = fw
[ lenth
/ i
* k
] ; } for ( rg
int j
= 0 ; j
< lenth
; j
+ = ( i
<< 1 ) ) { for ( rg
int k
= 0 ; k
< i
; k
++ ) { const complex x
= A
[ j
+ k
] , y
= W
[ k
] * A
[ j
+ k
+ i
] ; A
[ j
+ k
] = x
+ y
; A
[ j
+ k
+ i
] = x
- y
; } } }
}
int P
, p
;
int main ( )
{ read ( n
) , read ( m
) , read ( P
) ; p
= 31624 ; init ( n
+ m
) ; for ( rg
int i
= 0 ; i
<= n
; i
++ ) { int x
; read ( x
) ; ax
[ i
] . x
= x
/ p
, bx
[ i
] . x
= x
% p
; } for ( rg
int i
= 0 ; i
<= m
; i
++ ) { int x
; read ( x
) ; ay
[ i
] . x
= x
/ p
, by
[ i
] . x
= x
% p
; } FFT ( ax
, 1 ) , FFT ( bx
, 1 ) , FFT ( ay
, 1 ) , FFT ( by
, 1 ) ; for ( rg
int i
= 0 ; i
< lenth
; i
++ ) { const complex A
= ax
[ i
] , B
= bx
[ i
] , C
= ay
[ i
] , D
= by
[ i
] ; ax
[ i
] = A
* C
, ay
[ i
] = B
* D
; bx
[ i
] = A
* D
, by
[ i
] = B
* C
; } FFT ( ax
, - 1 ) , FFT ( bx
, - 1 ) , FFT ( ay
, - 1 ) , FFT ( by
, - 1 ) ; for ( rg
int i
= 0 ; i
<= n
+ m
; i
++ ) { const LL A
= ax
[ i
] . x
/ lenth
+ 0.5 , B
= ay
[ i
] . x
/ lenth
+ 0.5 , C
= bx
[ i
] . x
/ lenth
+ 0.5 , D
= by
[ i
] . x
/ lenth
+ 0.5 ; print ( ( A
% P
* p
% P
* p
% P
+ B
% P
+ ( C
% P
+ D
% P
) * p
% P
) % P
) , putchar ( ' ' ) ; } return flush ( ) , 0 ;
}
思路三(基于NTT的優化)
經過前面的分析,我們得知:FFT的運算結果≤NP2\le NP^2 ≤ N P 2 ,是102310^{23} 1 0 2 3 級別的 我們現在換一個思路,我們選出一些NTT模數(質數)(乘積大于FFT結果的最大值),求出在這些模意義下的值分別數多少,最后通過中國剩余定理(CRT)算出在給定模數的模意義下的值(選的質數一般是:469762049,998244353,1004535809469762049,998244353,1004535809 4 6 9 7 6 2 0 4 9 , 9 9 8 2 4 4 3 5 3 , 1 0 0 4 5 3 5 8 0 9 ) 但是我們發現所有質數的乘積爆long long了,所以不能直接CRT 設一個數的值為AnsAns A n s ,選取的三個質數分別為p1,p2,p3p_1,p_2,p_3 p 1 ? , p 2 ? , p 3 ? 我們通過6次DFT,3次IDFT算出在模意義下的值 Ans≡a1(modp1),Ans≡a2(modp2),Ans≡a3(modp3)Ans\equiv a_1\pmod {p_1},Ans\equiv a_2\pmod {p_2},Ans\equiv a_3\pmod {p_3} A n s ≡ a 1 ? ( m o d p 1 ? ) , A n s ≡ a 2 ? ( m o d p 2 ? ) , A n s ≡ a 3 ? ( m o d p 3 ? ) 根據中國剩余定理我們可以算出Ans=a4(modp1p2)Ans=a_4\pmod{p_1p_2} A n s = a 4 ? ( m o d p 1 ? p 2 ? ) 設Ans=a5p1p2+a4Ans=a_5p_1p_2+a_4 A n s = a 5 ? p 1 ? p 2 ? + a 4 ? ,我們已知a4a_4 a 4 ? ,如果能求出a5a_5 a 5 ? 就能求出Ans的值 我們發現因為Ans≡a3(modp3)Ans\equiv a_3\pmod {p_3} A n s ≡ a 3 ? ( m o d p 3 ? ) 所以a5p1p2≡a3?a4(modp3)a_5p_1p_2\equiv a_3-a_4\pmod {p_3} a 5 ? p 1 ? p 2 ? ≡ a 3 ? ? a 4 ? ( m o d p 3 ? ) 就能推出a5≡(a3?a4)p1?1p2?1(modp3)a_5\equiv (a_3-a_4)p_1^{-1}p_2^{-1}\pmod {p_3} a 5 ? ≡ ( a 3 ? ? a 4 ? ) p 1 ? 1 ? p 2 ? 1 ? ( m o d p 3 ? ) 然后直接計算就好了 代碼也非常好寫(這份代碼很不注重常數,只注重好寫)
#include <cstdio>
#include <cctype>
#include <cstring>
#include <cmath>
namespace fast_IO
{ const int IN_LEN
= 10000000 , OUT_LEN
= 10000000 ; char ibuf
[ IN_LEN
] , obuf
[ OUT_LEN
] , * ih
= ibuf
+ IN_LEN
, * oh
= obuf
, * lastin
= ibuf
+ IN_LEN
, * lastout
= obuf
+ OUT_LEN
- 1 ; inline char getchar_ ( ) { return ( ih
== lastin
) && ( lastin
= ( ih
= ibuf
) + fread ( ibuf
, 1 , IN_LEN
, stdin ) , ih
== lastin
) ? EOF : * ih
++ ; } inline void putchar_ ( const char x
) { if ( oh
== lastout
) fwrite ( obuf
, 1 , oh
- obuf
, stdout ) , oh
= obuf
; * oh
++ = x
; } inline void flush ( ) { fwrite ( obuf
, 1 , oh
- obuf
, stdout ) ; }
}
using namespace fast_IO
;
#define getchar() getchar_()
#define putchar(x) putchar_((x))
typedef long long LL
;
#define rg register
template
< typename T
> inline T
max ( const T a
, const T b
) { return a
> b
? a
: b
; }
template
< typename T
> inline T
min ( const T a
, const T b
) { return a
< b
? a
: b
; }
template
< typename T
> inline T
mind ( T
& a
, const T b
) { a
= a
< b
? a
: b
; }
template
< typename T
> inline T
maxd ( T
& a
, const T b
) { a
= a
> b
? a
: b
; }
template
< typename T
> inline T
abs ( const T a
) { return a
> 0 ? a
: - a
; }
template
< typename T
> inline void swap ( T
& a
, T
& b
) { T c
= a
; a
= b
; b
= c
; }
template
< typename T
> inline void swap ( T
* a
, T
* b
) { T c
= a
; a
= b
; b
= c
; }
template
< typename T
> inline T
gcd ( const T a
, const T b
) { if ( ! b
) return a
; return gcd ( b
, a
% b
) ; }
template
< typename T
> inline T
square ( const T x
) { return x
* x
; } ;
template
< typename T
> inline void read ( T
& x
)
{ char cu
= getchar ( ) ; x
= 0 ; bool fla
= 0 ; while ( ! isdigit ( cu
) ) { if ( cu
== '-' ) fla
= 1 ; cu
= getchar ( ) ; } while ( isdigit ( cu
) ) x
= x
* 10 + cu
- '0' , cu
= getchar ( ) ; if ( fla
) x
= - x
;
}
template
< typename T
> void printe ( const T x
)
{ if ( x
>= 10 ) printe ( x
/ 10 ) ; putchar ( x
% 10 + '0' ) ;
}
template
< typename T
> inline void print ( const T x
)
{ if ( x
< 0 ) putchar ( '-' ) , printe ( - x
) ; else printe ( x
) ;
}
const int maxn
= 262145 ;
int n
, m
;
struct Ntt
{ LL mod
, a
[ maxn
] , b
[ maxn
] ; ; inline LL
pow ( LL x
, LL y
) { rg LL res
= 1 ; for ( ; y
; y
>>= 1 , x
= x
* x
% mod
) if ( y
& 1 ) res
= res
* x
% mod
; return res
; } int lenth
, Reverse
[ maxn
] ; inline void init ( const int x
) { rg
int tim
= 0 ; lenth
= 1 ; while ( lenth
<= x
) lenth
<<= 1 , tim
++ ; for ( rg
int i
= 0 ; i
< lenth
; i
++ ) Reverse
[ i
] = ( Reverse
[ i
>> 1 ] >> 1 ) | ( ( i
& 1 ) << ( tim
- 1 ) ) ; } inline void NTT ( LL
* A
, const int fla
) { for ( rg
int i
= 0 ; i
< lenth
; i
++ ) if ( i
< Reverse
[ i
] ) swap ( A
[ i
] , A
[ Reverse
[ i
] ] ) ; for ( rg
int i
= 1 ; i
< lenth
; i
<<= 1 ) { LL w
= pow ( 3 , ( mod
- 1 ) / i
/ 2 ) ; if ( fla
== - 1 ) w
= pow ( w
, mod
- 2 ) ; for ( rg
int j
= 0 ; j
< lenth
; j
+ = ( i
<< 1 ) ) { LL K
= 1 ; for ( rg
int k
= 0 ; k
< i
; k
++ , K
= K
* w
% mod
) { const LL x
= A
[ j
+ k
] , y
= A
[ j
+ k
+ i
] * K
% mod
; A
[ j
+ k
] = ( x
+ y
) % mod
; A
[ j
+ k
+ i
] = ( mod
+ x
- y
) % mod
; } } } if ( fla
== - 1 ) { const int inv
= pow ( lenth
, mod
- 2 ) ; for ( rg
int i
= 0 ; i
< lenth
; i
++ ) A
[ i
] = A
[ i
] * inv
% mod
; } }
} Q
[ 3 ] ;
LL
EXgcd ( const LL a
, const LL b
, LL
& x
, LL
& y
)
{ if ( ! b
) { x
= 1 , y
= 0 ; return a
; } const LL res
= EXgcd ( b
, a
% b
, y
, x
) ; y
- = a
/ b
* x
; return res
;
}
inline LL
msc ( LL a
, LL b
, LL mod
)
{ LL v
= ( a
* b
- ( LL
) ( ( long double ) a
/ mod
* b
+ 1e-8 ) * mod
) ; return v
< 0 ? v
+ mod
: v
;
}
int N
, a
[ 3 ] , p
[ 3 ] ;
LL
CRT ( )
{ LL P
= 1 , sum
= 0 ; for ( rg
int i
= 1 ; i
<= N
; i
++ ) P
* = p
[ i
] ; for ( rg
int i
= 1 ; i
<= N
; i
++ ) { const LL m
= P
/ p
[ i
] ; LL x
, y
; EXgcd ( p
[ i
] , m
, x
, y
) ; sum
= ( sum
+ msc ( msc ( y
, m
, P
) , a
[ i
] , P
) ) % P
; } return sum
;
}
int P
;
int main ( )
{ read ( n
) , read ( m
) , read ( P
) ; Q
[ 0 ] . mod
= 469762049 , Q
[ 0 ] . init ( n
+ m
) ; Q
[ 1 ] . mod
= 998244353 , Q
[ 1 ] . init ( n
+ m
) ; Q
[ 2 ] . mod
= 1004535809 , Q
[ 2 ] . init ( n
+ m
) ; for ( rg
int i
= 0 ; i
<= n
; i
++ ) read ( Q
[ 0 ] . a
[ i
] ) , Q
[ 2 ] . a
[ i
] = Q
[ 1 ] . a
[ i
] = Q
[ 0 ] . a
[ i
] ; for ( rg
int i
= 0 ; i
<= m
; i
++ ) read ( Q
[ 0 ] . b
[ i
] ) , Q
[ 2 ] . b
[ i
] = Q
[ 1 ] . b
[ i
] = Q
[ 0 ] . b
[ i
] ; Q
[ 0 ] . NTT ( Q
[ 0 ] . a
, 1 ) , Q
[ 0 ] . NTT ( Q
[ 0 ] . b
, 1 ) ; Q
[ 1 ] . NTT ( Q
[ 1 ] . a
, 1 ) , Q
[ 1 ] . NTT ( Q
[ 1 ] . b
, 1 ) ; Q
[ 2 ] . NTT ( Q
[ 2 ] . a
, 1 ) , Q
[ 2 ] . NTT ( Q
[ 2 ] . b
, 1 ) ; for ( rg
int i
= 0 ; i
< Q
[ 0 ] . lenth
; i
++ ) Q
[ 0 ] . a
[ i
] = ( LL
) Q
[ 0 ] . a
[ i
] * Q
[ 0 ] . b
[ i
] % Q
[ 0 ] . mod
, Q
[ 1 ] . a
[ i
] = ( LL
) Q
[ 1 ] . a
[ i
] * Q
[ 1 ] . b
[ i
] % Q
[ 1 ] . mod
, Q
[ 2 ] . a
[ i
] = ( LL
) Q
[ 2 ] . a
[ i
] * Q
[ 2 ] . b
[ i
] % Q
[ 2 ] . mod
; Q
[ 0 ] . NTT ( Q
[ 0 ] . a
, - 1 ) ; Q
[ 1 ] . NTT ( Q
[ 1 ] . a
, - 1 ) ; Q
[ 2 ] . NTT ( Q
[ 2 ] . a
, - 1 ) ; N
= 2 , p
[ 1 ] = Q
[ 0 ] . mod
, p
[ 2 ] = Q
[ 1 ] . mod
; const int INV
= Q
[ 2 ] . pow ( Q
[ 0 ] . mod
, Q
[ 2 ] . mod
- 2 ) * Q
[ 2 ] . pow ( Q
[ 1 ] . mod
, Q
[ 2 ] . mod
- 2 ) % Q
[ 2 ] . mod
; for ( rg
int i
= 0 ; i
<= n
+ m
; i
++ ) { a
[ 1 ] = Q
[ 0 ] . a
[ i
] , a
[ 2 ] = Q
[ 1 ] . a
[ i
] ; const LL ans1
= CRT ( ) ; const LL ans2
= ( ( Q
[ 2 ] . a
[ i
] - ans1
) % Q
[ 2 ] . mod
+ Q
[ 2 ] . mod
) % Q
[ 2 ] . mod
* INV
% Q
[ 2 ] . mod
; print ( ( ans2
* Q
[ 0 ] . mod
% P
* Q
[ 1 ] . mod
% P
+ ans1
) % P
) , putchar ( ' ' ) ; } return flush ( ) , 0 ;
}
思路四(思路二的進階版)
容易發現,我們可以把那個pp p 的次數為11 1 的項直接合并 這樣就可以從調用8次DFT優化到調用7次 這里我就不另貼代碼了 (另外,優化到這里,這個算法的速度依然很慢,中國剩余定理常數過大) 另外還可以用多項式的奇技淫巧優化常數 資料參考:毛嘯,IOI2016國家集訓隊論文《再探快速傅里葉變換》 貼出myy給的代碼鏈接 由于7次的DFT/IDFT已經很快了,所以咕咕咕咕咕咕 以后有空閑時間再更吧 現在先貼個鏈接,是txc寫的任意模數 NTT 和 DFT 的優化學習筆記
總結
大概是比較清真的算法,如果推出來就很好記
超強干貨來襲 云風專訪:近40年碼齡,通宵達旦的技術人生
總結
以上是生活随笔 為你收集整理的任意模数NTT(MTT) 的全部內容,希望文章能夠幫你解決所遇到的問題。
如果覺得生活随笔 網站內容還不錯,歡迎將生活随笔 推薦給好友。