精华内容
下载资源
问答
  • 遗传算法求多项式最小值 C

    千次阅读 2018-09-14 13:52:25
    问题:在下面的程序中将要运用遗传算法对一个多项式最小值: y=x^6-10x^5-26x^4+344x^3+193x^2-1846x-1680 要求在(-8,8)间寻找使表达式达到最小的x,误差为0.001。 问题分析: 编码:采用常规码,即二进制码...

    问题:

    在下面的程序中将要运用遗传算法对一个多项式求最小值:
    y=x6-10x5-26x4+344x3+193x^2-1846x-1680
    要求在(-8,8)间寻找使表达式达到最小的x,误差为0.001。

    问题分析:
    函数式变换
    将函数解析式提取公因式可加快计算速度,如下式
    f(x)=x*(x*(x*(x*(x*(x-10)-26)+344)+193)-1846)-1680

    编码
    采用常规码,即二进制码编码。构造简单,交叉、变异的实现非常容易,同时解的表达也很简洁、直观。可以每0.001取一个点,这样理论误差讲小于0.0005,可以满足题目中的误差要求。此事总的求解空间为:
    N = (8 - (-8)) * 1000 = 16000,可以用n = 14位二进制来表示。

    PS:在可行域[Umin,Umax],用长度n的二进制编码表示参数,k为二进制编码的编码精度,则
    k=(Umax-Umin)/(2^n-1)

    基本遗传算法的运行参数:
    M:群体大小,一般取20~100;
    T:遗传算法的终止进化代数,一般取100~500;
    Pc:交叉概率,一般取0.4~0.99;
    Pm:变异概率,一般取0.0001~0.1。

    群体规模m:
    群体规模m可以选择 n ~ 2n 的一个确定的数,这里选择 m = 20

    初始种群的选取:
    在这里初始种群将在值域范围内随机选取

    终止规则:
    ①最优解在连续的20次循环中改变量小于0.01,此事认为这个最优解为满足题目要求的最优解,求解成功,退出程序
    ②总的循环次数大于1200次时,循环也将结束,这种情况按照求解失败处理

    交叉规则:
    采用最常用的双亲双子法

    选择:
    在进行交叉、变异后,种群中的个体个数达到2m个,将这2m个染色体按其适应度进行排序,保留最优的m个淘汰其他的,使种群在整体上得到进化

    程序代码:

    #include <stdio.h>
    #include <math.h>
    #include <stdlib.h>
    #include <time.h>
     
    #define SUM 20            //总共的染色体数量
    #define MAXloop 1200       //最大循环次数
    #define error 0.01        //若两次最优值之差小于此数则认为结果没有改变
    #define crossp 0.7        //交叉概率
    #define mp 0.04           //变异概率
     
     
    //用于求解函数y=x^6-10x^5-26x^4+344x^3+193x^2-1846x-1680在(-8,8)之间的最小值 
     
    struct gen                //定义染色体结构
    {
    	int info;        		//染色体结构,用一整型数的后14位作为染色体编码 
    	float suitability;		//次染色体所对应的适应度函数值,在本题中为表达式的值 
    };
    struct gen gen_group[SUM];//定义一个含有20个染色体的组
    struct gen gen_new[SUM];  
     
    struct gen gen_result;    //记录最优的染色体
    int result_unchange_time; //记录在error前提下最优值为改变的循环次数
     
    struct log                //形成链表,记录每次循环所产生的最优的适应度
    {
    	float suitability;
    	struct log *next;
    }llog,*head,*end;
    int log_num;              //链表长度
     
    /**************函数声明******************/ 
    void initiate();          	//初始化函数,主要负责产生初始化种群 
    void evaluation(int flag);	//评估种群中各染色体的适应度,并据此进行排序 
    void cross();				//交叉函数 
    void selection();			//选择函数 
    int  record();				//记录每次循环产生的最优解并判断是否终止循环 
    void mutation();			//变异函数 
    void showresult(int);		//显示结果 
    //-----------------------以上函数由主函数直接调用 
    int   randsign(float p);	//按照概率p产生随机数0、1,其值为1的概率为p 
    int   randbit(int i,int j);	//产生一个在i,j两个数之间的随机整数 
    int   randnum();			//随机产生一个由14个基因组成的染色体 
    int   convertionD2B(float x);//对现实解空间的可能解x进行二进制编码(染色体形式) 
    float convertionB2D(int x);	//将二进制编码x转化为现实解空间的值 
    int   createmask(int a);	//用于交叉操作 
     
    int main()
    {
    	int i,flag;
    	flag=0;
    	initiate();				//产生初始化种群 
        evaluation( 0 );		//对初始化种群进行评估、排序 
    	for( i = 0 ; i < MAXloop ; i++ )
    	{
    		cross();			//进行交叉操作 
    		evaluation( 1 );	//对子种群进行评估、排序 
    		selection();		//对父子种群中选择最优的NUM个作为新的父种群 
    		if( record() == 1 )	//满足终止规则1,则flag=1并停止循环 
    		{
    			flag = 1;
    			break;          //跳出for循环
    		}
    		mutation();			//变异操作 
    	}
    	showresult( flag );		//按照flag显示寻优结果 
    	return 0;
    }
     
    void initiate()            //初始化种群
    {
    	int i , stime;	
    	long ltime;
    	ltime=time(NULL);     //time()获取当前系统时间
    	stime=(unsigned)ltime/2;
    	srand(stime);         //srand()设置随机种子,保证每次运行得到不同的随机数
    	for( i = 0 ; i < SUM ; i++ )
    	{
    		gen_group[i].info = randnum();		//调用randnum()函数(获取随机数)建立初始种群	 
    	}
    	gen_result.suitability=1000;
    	result_unchange_time=0;
    	head=end=(struct log *)malloc(sizeof(llog));//初始化链表 
    	if(head==NULL)
    	{
    		printf("\n内存不够!\n");
    		exit(0);
    	}
    	end->next = NULL;
    	log_num = 1;
    }
     
    void evaluation(int flag)        //评估适应度,进行排序
    {
    	int i,j;
    	struct gen *genp;
    	int gentinfo;
    	float gentsuitability;
    	float x;
    	if( flag == 0 )			// flag=0的时候对父种群进行操作 
    		genp = gen_group;
    	else genp = gen_new;    // flag=1的时候对子代种群进行操作 
    	for(i = 0 ; i < SUM ; i++)//计算各染色体对应的表达式值
    	{
    		x = convertionB2D( genp[i].info );    //解码
    		genp[i].suitability = x*(x*(x*(x*(x*(x-10)-26)+344)+193)-1846)-1680;    //提取公因式比原式更快,适应度值=函数值
    	}
    	for(i = 0 ; i < SUM - 1 ; i++)//按表达式的值进行递增排序,
    	{
    		for(j = i + 1 ; j < SUM ; j++)
    		{
    			if( genp[i].suitability > genp[j].suitability )
    			{
    				gentinfo = genp[i].info;
    				genp[i].info = genp[j].info;
    				genp[j].info = gentinfo;
    				
    				gentsuitability = genp[i].suitability;
    				genp[i].suitability = genp[j].suitability;
    				genp[j].suitability = gentsuitability;		
    			}
    		}
    	}
    }
     
    void cross()                 //交叉函数
    {
    	int i , j , k ;
    	int mask1 , mask2;
    	int a[SUM];
    	for(i = 0 ; i < SUM ; i++)  a[i] = 0;
    	k = 0;
    	for(i = 0 ; i < SUM ; i++)
    	{
    		if( a[i] == 0)
    		{
    			for( ; ; )//随机找到一组未进行过交叉的染色体与a[i]交叉
    			{
       				j = randbit(i + 1 , SUM - 1);  //在i+1和SUM-1之间产生一个随机数
    				if( a[j] == 0)	break;
    			}
    			if(randsign(crossp) == 1)		//按照crossp的概率对选择的染色体进行交叉操作,randsign(p)按概率p返回1
    			{
    				mask1 = createmask(randbit(0 , 14)); //由ranbit选择交叉位 
    				mask2 = ~mask1;				//~为位运算符,表示按位取反
    				gen_new[k].info = (gen_group[i].info) & mask1 + (gen_group[j].info) & mask2;
    				gen_new[k+1].info=(gen_group[i].info) & mask2 + (gen_group[j].info) & mask1;
    				k = k + 2;
    			}
    			else 		//不进行交叉 
    			{
    				gen_new[k].info=gen_group[i].info;
    				gen_new[k+1].info=gen_group[j].info;
    				k=k+2;
    			}
    			a[i] = a[j] = 1;     //i、j完成交叉操作的标志位
    		}
    	}
    }
     
    void selection()         //选择函数
    {
    	int i , j , k;
    	j = 0;
    	i = SUM/2-1;
    	if(gen_group[i].suitability < gen_new[i].suitability)
    	{
    		for(j = 1 ; j < SUM / 2 ; j++)
    		{
    			if(gen_group[i+j].suitability > gen_new[i-j].suitability)
    				break;
    		}
    	}
    	else
    		if(gen_group[i].suitability>gen_new[i].suitability)
    		{
    			for(j=-1;j>-SUM/2;j--)
    			{
    				if(gen_group[i+j].suitability<=gen_new[i-j].suitability)
    					break;
    			}
    		}
    	for(k=j;k<SUM/2+1;k++)      //将更优秀的子代个体替换掉原来的父代个体
    	{
    		gen_group[i+k].info = gen_new[i-k].info;
    		gen_group[i+k].suitability = gen_new[i-k].suitability;
    	}	
    }
     
    int record()	//记录最优解和判断是否满足条件 
    {
    	float x;	
    	struct log *r;
    	r=(struct log *)malloc(sizeof(llog));    //申请结点
    	if(r==NULL)
    	{
    		printf("\n内存不够!\n");
    		exit(0);
    	}
    	r->next = NULL;
    	end->suitability = gen_group[0].suitability;   //头插法
    	end->next = r;
    	end = r;
    	log_num++;
     
    	x = gen_result.suitability - gen_group[0].suitability;
    	if(x < 0)x = -x;
    	if(x < error)      
    	{
    		result_unchange_time++;
    		if(result_unchange_time >= 20)return 1;  //连续20次循环中函数值改变量小于0.01,则终止进化
    	}
    	else     //记录最优解,继续变异
    	{
    		gen_result.info = gen_group[0].info;
    		gen_result.suitability = gen_group[0].suitability;
    		result_unchange_time=0;
    	}
    	return 0;
    }
     
    void mutation()       //变异函数
    {
    	int i , j , m;
    	float x;
    	float gmp;
    	int gentinfo;
    	float gentsuitability;
    	gmp = 1 - pow(1 - mp , 11);//在基因变异概率为mp时整条染色体的变异概率
    	for(i = 0 ; i < SUM ; i++)
    	{
    		if(randsign(gmp) == 1)  //randsign(gmp)按概率gmp产生1
    		{
    			j = randbit(0 , 14);  //产生0~14的随机数
    			m = 1 << j;     //整数i左移j位
    			gen_group[i].info = gen_group[i].info^m;
    			x = convertionB2D(gen_group[i].info);  //解码
    			gen_group[i].suitability = x*(x*(x*(x*(x*(x-10)-26)+344)+193)-1846)-1680;
    		}
    	}
    	for(i = 0 ; i < SUM - 1 ; i++)   //递增排序
    	{
    		for(j = i + 1 ; j < SUM ; j++)
    		{
    			if(gen_group[i].suitability > gen_group[j].suitability)
    			{
    				gentinfo = gen_group[i].info;
    				gen_group[i].info = gen_group[j].info;
    				gen_group[j].info = gentinfo;
    				
    				gentsuitability = gen_group[i].suitability;
    				gen_group[i].suitability = gen_group[j].suitability;
    				gen_group[j].suitability = gentsuitability;
    			}
    		}
    	}
    	/*
    	*为了提高执行速度,在进行变异操作的时候并没有直接确定需要进行变异的位
    	*而是先以cmp概率确定将要发生变异的染色体,再从染色体中随进选择一个基因进行变异
    	*由于进行选择和变异后父代种群的次序已被打乱,因此,在变异前后对种群进行一次排序
    	*/ 
    }
     
    void showresult(int flag)//显示搜索结果并释放内存
    {
    	int i , j;
    	struct log *logprint,*logfree;
    	FILE *logf;
    	if(flag == 0)
    		printf("已到最大搜索次数,搜索失败!");
    	else 
    	{
    		printf("当取值%f时表达式达到最小值为%f\n",convertionB2D(gen_result.info),gen_result.suitability);
    		printf("收敛过程记录于文件log.txt");
    		if((logf = fopen("log.txt" , "w+")) == NULL)
    		{
    			printf("Cannot create/open file");
    			exit(1);
    		}
    		logprint=head;
    		for(i = 0 ; i < log_num ; i = i + 5)//对收敛过程进行显示
    		{
    			for(j = 0 ; (j < 5) & ((i + j) < log_num-1) ; j++)
    			{
    				fprintf(logf , "%20f" , logprint->suitability);
    				logprint=logprint->next;				
    			}
    			fprintf(logf,"\n\n");
    		}
    	}
    	for(i = 0 ; i< log_num ; i++)//释放内存
    	{
    		logfree=head;
    		head=head->next;
    		free(logfree);
    		fclose(logf);
    	}
    	getchar();
    }
     
    int randsign(float p)   //按概率p返回1
    {
    	if(rand() > (p * 32768))
    		return 0;
    	else return 1;
    }
    
    int randbit(int i, int j)  //产生在i与j之间的一个随机数
    {
    	int a , l;
    	l = j - i + 1;
    	a = i + rand() * l / 32768;
    	return a;
    }
    
    int randnum()    //产生随机数
    {
    	int x;
    	x = rand() / 2;
    	return x;
    }
    
    float convertionB2D(int x)    //解码
    {
    	float y;
    	y = x;
    	y = (y - 8192) / 1000;
    	return y;
    }
    
    int convertionD2B(float x)     //编码
    {
    	int g;
    	g = (x * 1000) + 8192;
    	return g;
    }
    
    int createmask(int a)   //交叉操作
    {
    	int mask;
    	mask=(1 << (a + 1)) - 1;
    	return mask;
    }
    
    
    • **rand(T):**获取当前系统时间,当参数T为空指针(NULL)时,只返回值(从1900年1月1日到现在的时间秒数。

    • **srand():**设置一个随机种子,每次运行能保证随机种子不同。
      注:rand()是根据种子为基准以某个递推公式推算出来的一个伪随机数。(种子在计算机开机后就定了的,只能通过srand()函数来改变种子值。

    • 编码:x取值范围[-8,8],取数间隔为0.001,编码是将float数转换为int数,所以将x1000,范围变成[-8000,8000]。又因为前面已经选定用14位二进制来表示变量,但是负数和整数符号位不同,14位不足以表示,所以要**x1000+2^13**,使得转换所得均为正整数,可以用14位二进制编码表示。

    程序运行结果

    计算结果

    这里写图片描述

    展开全文
  • 多项式合集

    2021-02-20 19:13:39
    文章目录多项式合集拉格朗日插值问题背景结论推导拉格朗日插值与范德蒙矩阵开始全家桶之前形式化定义界快速傅里叶变换(FFT)复数基础从欧拉公式到单位圆多项式的表示法单位复数根DFTIDFT位逆序置换多项式乘法的实现...

    多项式合集

    拉格朗日插值

    问题背景

    给出 nn 个点 (xi,yi)(x_i,y_i),令这 nn 个点确定的多项式为 L(x)L(x),求 L(k)mod998244353L(k)\bmod 998244353 的值。

    结论

    L(x)=i=1nyili(x) L(x) = \sum_{i=1}^n y_il_i(x)

    其中每个 li(x)l_i(x) 为拉格朗日基本多项式,表达式为

    li(x)=j=1,jinxxjxixj l_i(x) = \prod_{j=1,j\ne i}^n\frac{x-x_j}{x_i-x_j}

    其特点是 li(xi)=1l_i(x_i)=1ji\forall j\ne ili(xj)=0l_i(x_j)=0

    推导

    抛开拉插,这道题明显可以列方程组然后使用高斯消元求解,但是复杂度为 O(n3)O(n^3) 且精度问题明显,所以拉格朗日是这样考虑的:

    对于每个点 Pi(xi,yi)P_i(x_i,y_i),构造一个 n1n-1 次多项式 li(x)l_i(x) 使其在 xix_i 上取值为 11,在其余 xjx_j 上为 00。构造的结果就是上面的结论:

    li(x)=j=1,jinxxjxixj l_i(x) = \prod_{j=1,j\ne i}^n\frac{x-x_j}{x_i-x_j}

    这个多项式的正确性还是很显然的。然后我们也知道这个多项式它就是唯一的。

    然后考虑构造答案:很显然对于点 Pi(xi,yi)P_i(x_i,y_i),只有 li(xi)l_i(x_i) 的取值为 11,其他的都为 00。所以答案的正确性也是比较显然的:对于 xix_i,只有 yili(xi)y_il_i(x_i) 产生了贡献,其余的都是 00。故这个多项式是正确的。

    所以回到一开始,我们需要的就是

    f(k)=i=1nyij=1,jinkxjxixj f(k) = \sum_{i=1}^n y_i\prod_{j=1,j\ne i}^n\frac{k-x_j}{x_i-x_j}

    由于模数是质数,所以使用费马小定理求逆元,跑得飞快。

    复杂度 O(n2)O(n^2),求逆元就是个很小的常数

    #include <cstdio>
    #include <cctype>
    #define il inline
    
    typedef long long ll;
    
    inline ll read()
    {
        char c = getchar();
        ll s = 0;
        bool x = 0;
        while (!isdigit(c))
            x = x | (c == '-'), c = getchar();
        while (isdigit(c))
            s = 10 * s + c - '0', c = getchar();
        return x ? -s : s;
    }
    
    const ll maxn = 2e3 + 5, mod = 998244353;
    ll x[maxn], y[maxn];
    
    ll pow(ll base, ll p)
    {
        ll ans = 1;
        base = (base + mod) % mod;
        for (; p; p >>= 1)
        {
            if (p & 1)
                ans = ans * base % mod;
            base = base * base % mod;
        }
        return ans;
    }
    
    il ll inv(ll n)
    {
        return pow(n, mod - 2);
    }
    
    int main()
    {
        ll n = read(), k = read();
        for (int i = 1 ; i <= n; ++i)
            x[i] = read(), y[i] = read();
        ll ans = 0;
        for (int i = 1; i <= n; ++i)
        {
            ll prod1 = 1, prod2 = 1;
            for (int j = 1; j <= n; ++j)
            {
                if (i == j)
                    continue;
                prod1 = prod1 * (k - x[j]) % mod;
                prod2 = prod2 * (x[i] - x[j]) % mod;
            }
            ans = (ans + prod1 * y[i] % mod * inv(prod2) % mod + mod) % mod;
        }
        printf("%lld\n", ans);
        return 0;
    }
    

    拉格朗日插值与范德蒙矩阵

    可以考虑将这 n+1n+1 个点值表示为如下形式:

    [x00x01x02x0nx10x11x12x1nxn0xn1xn2xnn][a0a1an]=[y0y1yn] \begin{bmatrix} x_0^0 & x_0^1 & x_0^2 &\cdots &x_0^n\\ x_1^0 & x_1^1 & x_1^2 &\cdots & x_1^n\\ \vdots & \vdots & \vdots & &\vdots\\ x_n^0 & x_n^1 & x_n^2 & \cdots & x_n^n \end{bmatrix} \begin{bmatrix} a_0\\a_1\\ \vdots \\ a_n \end{bmatrix}=\begin{bmatrix} y_0\\y_1\\ \vdots\\ y_n \end{bmatrix}

    左边这个矩阵就是所谓的范德蒙德矩阵,记作 V\boldsymbol V,系数列向量记作 A\boldsymbol A,右边的记作 B\boldsymbol B,则很明显:

    VA=B \boldsymbol{VA} = \boldsymbol B

    打开来看清楚些实际就是多项式 ff 在每个点处的值:

    yj=f(xj)=i=0naixji y_j = f(x_j) = \sum_{i = 0}^na_ix_j^i

    我们把两边都乘上 V1\boldsymbol V^{-1}

    [a0a1an]=V1[y0y1yn] \begin{bmatrix} a_0\\a_1\\ \vdots \\ a_n \end{bmatrix}=\boldsymbol V^{-1} \begin{bmatrix} y_0\\y_1\\ \vdots\\ y_n \end{bmatrix}

    就得到了 aia_i 一定可以表示为某种形如

    ak=[]yk a_k = \sum \begin{bmatrix} \vdots \end{bmatrix}y_k

    的形式,即 aka_k 只与 xix_iyky_k 有关。

    所以不难发现对于一个要求的 f(x)f(x_\ominus),都可以被表示为如下形式

    f(x)=δk(x)yk f(x_\ominus)=\sum\delta_k(x_\ominus)y_k

    δk(x)\delta_k(x) 构造的过程即需要考虑 x=xkx=x_kδj(x)=0δk(x)=1\delta_j(x) = 0\land\delta_k(x) = 1,其中 kjk\not=jδj(xk)=0\delta_j(x_k) = 0 说明每一个 δj\delta_j 都要有 (xxk)(x-x_k) 这个因式,然后又因为 δk(xk)=1\delta_k(x_k) = 1,所以最终构造出来就是上面的结果:

    f(x)=i=1nyij=1,jinxxjxixj f(x) = \sum_{i=1}^n y_i\prod_{j=1,j\ne i}^n\frac{x-x_j}{x_i-x_j}

    我们其实也可以利用拉格朗日插值来求范德蒙矩阵的逆阵,复杂度 O(n2)O(n^2)

    开始全家桶之前

    形式化定义

    约定:fif_i 表示 f(x)f(x)xix^i 处的系数,即一个多项式可以表示为 i=0fixi\displaystyle\sum_{i = 0} f_ix^i 的形式。

    两个多项式的加减法定义为

    f(x)±g(x)=i=0(fi±gi)xi f(x) \pm g(x) = \sum_{i = 0}(f_i \pm g_i)x^i

    复杂度 O(n)O(n)

    两个多项式的乘法(加法卷积)定义为:

    f(x)g(x)=i=0xij=0fjgij f(x)*g(x) = \sum_{i = 0}x^i\sum_{j = 0}f_jg_{i - j}

    不难发现其正确性。可以手动模拟一下多项式的乘法看看是不是这样子的。其本质也就是卷完之后合并同类项。朴素的做的话复杂度是 O(n2)O(n^2) 的,下面要讲的 FFT/NTT 可以加速到 O(nlogn)O(n\log n)

    有些时候,题目只对多项式的前若干项感兴趣,所以我们给运算设定一个上界,即 (modxn)\pmod{x^n}。意思就是只考虑这个多项式的前 nn,从 xnx^n 开始以后的全部舍掉。

    不难发现由加法和乘法是从低位到高位贡献的,所以

    (f(x)modxn±g(x)modxn)modxn=(f(x)±g(x))modxn(f(x)modxn)(g(x)modxn)modxn=(f(x)g(x))modxn \begin{aligned} (f(x) \bmod{x^n} \pm g(x)\bmod{x^n})\bmod{x^n} &= (f(x) \pm g(x))\bmod{x^n}\\ (f(x) \bmod{x^n}) * (g(x)\bmod{x^n})\bmod{x^n} &= (f(x) * g(x))\bmod{x^n}\\ \end{aligned}

    下面我们就开始学习多项式的各种操作吧

    快速傅里叶变换(FFT)

    FFT 可以加速卷积,让时间复杂度从 O(n2)O(n^2) 降到 O(nlogn)O(n\log n),学习 FFT 的基础操作前,需要先了解复数,因为 FFT 就是基于单位复数根的良好性质实现的。

    复数基础

    (数学选修 2-2 内容)

    定义虚数单位 i2=1\mathrm i^2 = \sqrt{-1},把形如 a+bi(a,bR)a + b\mathrm i\:(a,b\in\mathbb R) 的数称为复数,所有复数的集合称为复数集 C\mathbb C

    复数一般使用 zz 表示,表示为 z=a+biz = a + b\mathrm i,这种形式称为复数的代数形式。aa 被称为复数的实部,bb 称为复数的虚部,未加说明的情况下一般认为 a,bRa,b\in\mathbb R。很明显地,当 a=0b0a = 0\land b\not=0 时,这个复数为纯虚数,当 b=0b=0 时,这个复数为实数。

    每个复数 a+bia + b\mathrm i 都能对应平面直角坐标系里面的一个点 (a,b)(a,b),同样的也可以对应一个向量 (a,b)(a,b)。故定义复数的模为 a2+b2\sqrt{a^2 + b^2}

    定义复数的加法与乘法:
    (a+bi)+(c+di)=(a+c)+(b+d)i \begin{aligned} &(a + b\mathrm i) + (c + d\mathrm i)\\ =&(a + c) + (b + d)\mathrm i \end{aligned}

    (a+bi)(c+di)=ac+adi+cbi+bdi2=(acbd)+(ad+bc)i \begin{aligned} &(a+b\mathrm i)(c + d\mathrm i)\\ =&ac + ad\mathrm i + cb\mathrm i + bd\mathrm i^2\\ =&(ac - bd) + (ad + bc)\mathrm i \end{aligned}

    这都是比较显然的。

    容易看出复数满足很多实数的运算律。

    定义复数 z=a+biz=a+b\mathrm i 的共轭复数为 z=abi\overline{z} = a - b\mathrm i,不难发现 zzz\overline{z} 关于实轴对称。
    zz=(a+bi)(abi)=a2+b2=z2 z\overline z=(a+b\mathrm i)(a-b\mathrm i) = a^2 + b^2=|z|^2
    复数既然可以对应平面直角坐标系中的向量,不难发现其可以使用其模长与辐角来表示:
    z=a+bi    z=r(cosθ+isinθ) z=a+b\mathrm i\iff z = r(\cos\theta+\mathrm i\sin\theta)
    其中 rrzz 的模长,θ\theta 为其辐角。即我们可以把一个复数表示成二元组 (r,θ)(r,\theta) 的形式。

    现在考虑两个复数 (r1,θ1)(r_1,\theta_1)(r2,θ2)(r_2,\theta_2) 相乘得到的结果:
    (r1,θ1)×(r2,θ2)=r1(cosθ1+isinθ1)r2(cosθ2+isinθ2)=r1r2(cosθ1cosθ2sinθ1sinθ2+isinθ1cosθ2+isinθ2cosθ1)=r1r2(cos(θ1+θ2)+isin(θ1+θ2))=(r1r2,θ1+θ2) \begin{aligned} (r_1,\theta_1)\times(r_2,\theta_2) &= r_1(\cos\theta_1 + \mathrm i\sin\theta_1)r_2(\cos\theta_2 + \mathrm i\sin\theta_2)\\ &=r_1r_2(\cos\theta_1\cos\theta_2 - \sin\theta_1\sin\theta_2 + \mathrm i\sin\theta_1\cos\theta_2 + \mathrm i\sin\theta_2\cos\theta_1)\\ &=r_1r_2\left(\cos(\theta_1 + \theta_2) + \mathrm i\sin(\theta_1 + \theta_2)\right)\\ &=(r_1r_2,\theta_1 + \theta_2) \end{aligned}
    于是我们可以概括复数乘法的法则:模长相乘,辐角相加。(上述推导需要掌握基本的三角恒等变换)

    从欧拉公式到单位圆

    给出复数指数幂的定义:
    ex+yi=ex(cosy+isiny) \mathrm e^{x +y\mathrm i} = e^x(\cos y + \mathrm i\sin y)
    这个公式是由我也不会证明的泰勒展开推导出来的:
    sin(x)=xx33!+x55!x77!+x99!+=k=1(1)k1x2k1(2k1)!cos(x)=1x22!+x44!x66!+x88!+=k=0(1)kx2k(2k)!ex=1+x+x22!+x33!+x44!+=k=0xkk! \begin{aligned} \sin(x) &= x - \frac{x^3}{3!}+\frac{x^5}{5!} - \frac{x^7}{7!} + \frac{x^9}{9!} + \cdots = \sum_{k = 1}^\infin\frac{(-1)^{k - 1}x^{2k - 1}}{(2k-1)!}\\ \cos(x) &= 1 - \frac{x^2}{2!} + \frac{x^4}{4!} - \frac{x^6}{6!} + \frac{x^8}{8!} + \cdots = \sum_{k = 0}^\infin\frac{(-1)^{k} x^{2k}}{(2k)!}\\ \mathrm e^x &= 1 + x + \frac{x^2}{2!} + \frac{x^3}{3!} + \frac{x^4}{4!} + \cdots = \sum_{k = 0}^\infin\frac{x^k}{k!} \end{aligned}
    x+yix + y\mathrm i 代入进去即可推导。

    如果 x=0x = 0,我们就得到大名鼎鼎的欧拉公式:
    exi=cosx+isinx \mathrm e^{x\mathrm i} = \cos x + \mathrm i\sin x
    更特殊地,如果 x=πx = \pi,得到的就是下面这个神奇的式子:
    eπi=1 \mathrm e^{\pi\mathrm i} = -1
    复平面上我们可以定义类似于平面直角坐标系上的单位圆,单位圆上的所有复数构成集合 {z:z=1}\{z: |z| = 1\}。这些复数都可以表示为 cosθ+isinθ\cos\theta + \mathrm i\sin\thetaeθie^{\theta \mathrm i} 的形式。

    多项式的表示法

    系数表示法:顾名思义
    f(x)=a0+a1x+a2x2++anxn    f(x)={a0,a1,a2,,an}=[x0x1x2xn][a0a1a2an] f(x) = a_0 + a_1x + a_2x^2 + \cdots + a_nx^n\iff f(x) = \{a_0,a_1,a_2,\cdots,a_n\} = \begin{bmatrix} x^0 & x^1 & x^2 &\cdots & x^n \end{bmatrix} \begin{bmatrix} a_0\\a_1\\a_2\\\vdots\\a_n \end{bmatrix}
    点值表示法:

    我们知道由一个多项式在 n+1n + 1 个点上的取值是可以唯一确定一个多项式的,其本质也就是线性方程组的解。所以一个 nn 次多项式可以用 n+1n + 1 个点表示:

    f(x)={(x0,y0),(x1,y1),,(xn,yn)} f(x) = \{(x_0,y_0),(x_1,y_1),\cdots,(x_n,y_n)\}

    或者:

    [x00x01x02x0nx10x11x12x1nxn0xn1xn2xnn][a0a1an]=[y0y1yn] \begin{bmatrix} x_0^0 & x_0^1 & x_0^2 &\cdots &x_0^n\\ x_1^0 & x_1^1 & x_1^2 &\cdots & x_1^n\\ \vdots & \vdots & \vdots & &\vdots\\ x_n^0 & x_n^1 & x_n^2 & \cdots & x_n^n \end{bmatrix}\begin{bmatrix} a_0\\a_1\\ \vdots \\ a_n \end{bmatrix} =\begin{bmatrix} y_0\\y_1\\ \vdots\\ y_n \end{bmatrix}

    通过下面的这个形式我们看得出来其就是一个典型的线性方程组的形式,不难证明其解的唯一性。

    并且我们发现点值表示法有一个很明显的优势:可以在 O(n)O(n) 的时间内将两个多项式乘起来,只需把对应点的 yy 乘起来即可。

    通俗的来说,FFT 实现的就是快速求多项式乘法的过程:先把系数表示法转成点值表示法(DFT,离散傅里叶变换),乘完之后再把点值还原为插值(IDFT,离散傅里叶逆变换)。可是朴素的 DFT 需要的时间复杂度为 O(n2)O(n^2),IDFT 还回其系数需要高斯消元是 O(n3)O(n^3) 的。而 FFT 利用了一些很特殊很特殊的值加速了 DFT 和 IDFT 的过程,使得总时间复杂度降低到了 O(nlogn)O(n\log n)

    单位复数根

    解这个方程:
    xn=1 x^n = 1
    我们会发现这个方程在实数范围内只有 11 或者 22 个解。然而代数基本定理告诉我们这样的方程有 nn 个复数域上的解。由模长相乘辐角相加我们知道因为最终 xn=1x^n = 1,所以这些满足条件的 xx 的模长必定也是 11。然后需要满足他们的辐角的 nn 倍能被 2π2\pi 整除。

    不难发现其就是 nn 等分单位圆:

    img

    我们记 nn 次单位根的第 kk 个记为 ωnk\omega_n^k,不难发现 ωkn=e2kπin\omega_k^n = \mathrm e^{\frac{2k\pi i}{n}}。由此可见,单位复数根具有一些非常好的性质比如:
    ωn0=ωnn=1ωnk=ω2n2kω2nk+n=ω2nk(ω2nk+n)2=ωnk \begin{aligned} \omega_n^0 = \omega_n^n &= 1\\ \omega_n^k &= \omega_{2n}^{2k}\\ \omega_{2n}^{k + n} &= -\omega_{2n}^k\\ \left(\omega_{2n}^{k + n}\right)^2 &=\omega_n^k \end{aligned}
    利用这些性质,我们可以加速 DFT 的过程。FFT 就是利用分治思想加速求每个 f(ωnk)f(\omega_n^k) 的值

    DFT

    此时 DFT 的分治思想就是分开考虑奇次项和偶次项:

    考虑
    f(x)=a0x0+a1x1+a2x2+ f(x) = a_0x^0 + a_1x^1 + a_2x^2 + \cdots
    将其分为两个多项式
    f(x)=a0x0+a2x2+a4x4+a6x6+a8x8++a1x1+a3x3+a5x5+a7x7+a9x9+=a0x0+a2x2+a4x4+a6x6+a8x8++x(a1x0+a3x2+a5x4+a7x6+) \begin{aligned} f(x) &= a_0x^0 + a_2x^2 + a_4x^4 + a_6x^6 + a_8x^8 + \cdots +a_1x^1 + a_3x^3 + a_5x^5 + a_7x^7 + a_9x^9 + \cdots\\ &= a_0x^0 + a_2x^2 + a_4x^4 + a_6x^6 + a_8x^8+\cdots +x(a_1x^0 + a_3x^2 + a_5x^4 + a_7x^6 + \cdots) \end{aligned}
    考虑两个新多项式:
    f0(x)=a0x0+a2x1+a4x2+a6x3+f1(x)=a1x0+a3x1+a5x2+a7x3+ \begin{aligned} f_0(x) &= a_0x^0 + a_2x^1 + a_4x^2 + a_6x^3 + \cdots\\ f_1(x) &= a_1x^0 + a_3x^1 + a_5x^2 + a_7x^3 + \cdots \end{aligned}
    不难发现
    f(x)=f0(x2)+xf1(x2) f(x) = f_0(x^2) + xf_1(x^2)
    利用单位复数根的性质:
    DFT(f(ωnk))=DFT(f0(ωn2k))+ωnkDFT(f1(ωn2k))=DFT(f0(ωn2k))+ωnkDFT(f1(ωn2k)) \begin{aligned} \mathrm{DFT}(f(\omega_n^k)) &= \mathrm{DFT}(f_0(\omega_n^{2k})) + \omega_n^k\mathrm{DFT}(f_1(\omega_n^{2k}))\\ &=\mathrm{DFT}(f_0(\omega_\frac n2^k)) + \omega_n^k\mathrm{DFT}(f_1(\omega_\frac n2^k)) \end{aligned}

    DFT(f(ωnk+n2))=DFT(f0(ωn2k+n))+ωnk+n2DFT(f1(ωn2k+n))=DFT(f0(ωnnωn2k))ωnkDFT(f1(ωnnωn2k))=DFT(f0(ωn2k))ωnkDFT(f1(ωn2k)) \begin{aligned} \mathrm{DFT}(f(\omega_n^{k + \frac n2})) &= \mathrm{DFT}(f_0(\omega_n^{2k + n})) + \omega_{n}^{k + \frac n2}\mathrm{DFT}(f_1(\omega_n^{2k + n}))\\ &=\mathrm{DFT}(f_0(\omega_n^n\omega_n^{2k})) - \omega_n^k\mathrm{DFT}(f_1(\omega_n^n\omega_n^{2k}))\\ &=\mathrm{DFT}(f_0(\omega_\frac n2^k)) - \omega_n^k\mathrm{DFT}(f_1(\omega_\frac n2^k)) \end{aligned}

    其中 k<n2k < \displaystyle\frac n2。不难发现只要我们求得出 DFT(f0(ωn2k))\mathrm{DFT}(f_0(\omega_\frac n2^k))DFT(f1(ωn2k))\mathrm{DFT}(f_1(\omega_\frac n2^k)) 的话,就可以同时求出 DFT(f(ωnk))\mathrm{DFT}(f(\omega_n^k))DFT(f(ωnk+n2))\mathrm{DFT}(f(\omega_n^{k + \frac n2}))。接下来再对 f0f_0f1f_1 递归 DFT 即可。其时间复杂度函数是形如下面这样的:
    T(n)=T(n/2)+O(n) T(n) = T(n/2) + O(n)
    所以总复杂度为 Θ(nlogn)\Theta(n\log n)

    实际实现的时候一定要注意传进去的系数一定要是 2m2^m 个的,不然分治的过程中左右不一样会出问题。第一次传进去的时候就高位补 00,补成最高项次数为 2m12^{m - 1} 的多项式。

    void dft(int lim, complex *a)
    {
        if (lim == 1) return;//常数项直接返回
        complex a1[lim >> 1], a2[lim >> 1];
        for (int i = 0; i < lim; i += 2)
            a1[i >> 1] = a[i], a2[i >> 1] = a[i + 1];//把系数按照奇偶分开
        dft(lim >> 1, a1, type);//求 DFT(f_0())
        dft(lim >> 1, a2, type);//求 DFT(f_1())
        complex Wn = complex(cos(2.0 * pi / lim), sin(2.0 * pi / lim)), w = complex(1, 0);
        for (int i = 0; i < (lim >> 1); ++i, w = w * Wn)
        {
            a[i] = a1[i] + w * a2[i];//求 DFT(f(\omega_n^k))
            a[i + (lim >> 1)] = a1[i] - w * a2[i];//求 DFT(f(\omega_n^{k+\fracn2}))
        }
        return;
    }
    

    IDFT

    好了现在假装我们已经求出了两个多项式的点值表达并已经将他们乘起来,但是我们最终还是要把他还原回去到系数表示的。这个过程就叫做 IDFT。

    其实就是我们需要求解下面关于 aa 的线性方程组:

    [(ωn0)0(ωn0)1(ωn0)2(ωn0)n(ωn1)0(ωn1)1(ωn1)2(ωn1)n(ωnn)0(ωnn)1(ωnn)2(ωnn)n][a0a1an]=[y0y1yn] \begin{bmatrix} (\omega_n^0)^0 & (\omega_n^0)^1 & (\omega_n^0)^2 &\cdots &(\omega_n^0)^n\\ (\omega_n^1)^0 & (\omega_n^1)^1 & (\omega_n^1)^2 &\cdots & (\omega_n^1)^n\\ \vdots & \vdots & \vdots & &\vdots\\ (\omega_n^{n})^0 & (\omega_n^{n})^1 & (\omega_n^{n})^2 & \cdots & (\omega_n^n)^n \end{bmatrix} \begin{bmatrix} a_0\\a_1\\ \vdots \\ a_n \end{bmatrix}=\begin{bmatrix} y_0\\y_1\\ \vdots\\ y_n \end{bmatrix}

    我们将其乘上左边矩阵的逆:

    [a0a1an]=[(ωn0)0(ωn0)1(ωn0)2(ωn0)n(ωn1)0(ωn1)1(ωn1)2(ωn1)n(ωnn)0(ωnn)1(ωnn)2(ωnn)n]1[y0y1yn] \begin{bmatrix} a_0\\a_1\\ \vdots \\ a_n \end{bmatrix}=\begin{bmatrix} (\omega_n^0)^0 & (\omega_n^0)^1 & (\omega_n^0)^2 &\cdots &(\omega_n^0)^n\\ (\omega_n^1)^0 & (\omega_n^1)^1 & (\omega_n^1)^2 &\cdots & (\omega_n^1)^n\\ \vdots & \vdots & \vdots & &\vdots\\ (\omega_n^{n})^0 & (\omega_n^{n})^1 & (\omega_n^{n})^2 & \cdots & (\omega_n^n)^n \end{bmatrix}^{-1} \begin{bmatrix} y_0\\y_1\\ \vdots\\ y_n \end{bmatrix}

    模相同的正交列向量构成的矩阵的逆是转置的模分之一倍,所以:

    [(ωn0)0(ωn0)1(ωn0)2(ωn0)n(ωn1)0(ωn1)1(ωn1)2(ωn1)n(ωnn)0(ωnn)1(ωnn)2(ωnn)n]1=1n+1[(ωn0)0(ωn0)1(ωn0)2(ωn0)n(ωn1)0(ωn1)1(ωn1)2(ωn1)n(ωnn)0(ωnn)1(ωnn)2(ωnn)n] \begin{bmatrix} (\omega_n^0)^0 & (\omega_n^0)^1 & (\omega_n^0)^2 &\cdots &(\omega_n^0)^n\\ (\omega_n^1)^0 & (\omega_n^1)^1 & (\omega_n^1)^2 &\cdots & (\omega_n^1)^n\\ \vdots & \vdots & \vdots & &\vdots\\ (\omega_n^{n})^0 & (\omega_n^{n})^1 & (\omega_n^{n})^2 & \cdots & (\omega_n^n)^n \end{bmatrix}^{-1} =\frac {1}{n+1} \begin{bmatrix} (\omega_n^{-0})^0 & (\omega_n^{-0})^1 & (\omega_n^{-0})^2 &\cdots &(\omega_n^{-0})^n\\ (\omega_n^{-1})^0 & (\omega_n^{-1})^1 & (\omega_n^{-1})^2 &\cdots & (\omega_n^{-1})^n\\ \vdots & \vdots & \vdots & &\vdots\\ (\omega_n^{-n})^0 & (\omega_n^{-n})^1 & (\omega_n^{-n})^2 & \cdots & (\omega_n^{-n})^n \end{bmatrix}

    所以不难发现,IDFT 其实就是再做了一遍 DFT,只不过是反起来的。只是算出来最后的系数结果都要除以点值的个数,反应在代码里面就是那个 lim 变量。

    不难发现 ωnk\omega_n^k 的共轭就是虚部取反,所以可以在 DFT 函数里面传一个参数表示是否为 IDFT。

    这样子一个递归版的 FFT 就写完了,总体的代码如下:

    #include <cstdio>
    #include <cctype>
    #include <cmath>
    #define FOR(i, a, b) for (int i = a; i <= b; ++i)
    
    const int maxn = 2e6 + 5;
    const double pi = acos(-1.0);
    
    inline int read()
    {
        char c = getchar();
        int s = 0;
        while (!isdigit(c))
            c = getchar();
        while (isdigit(c))
            s = 10 * s + c - '0', c = getchar();
        return s;
    }
    
    struct complex
    {
        double x, y;
        complex(double xx = 0, double yy = 0)
        {
            x = xx, y = yy;
        }
    } a[maxn], b[maxn];
    
    complex operator+(const complex &a, const complex &b) {return complex(a.x + b.x, a.y + b.y);}
    complex operator-(const complex &a, const complex &b) {return complex(a.x - b.x, a.y - b.y);}
    complex operator*(const complex &a, const complex &b) {return complex(a.x * b.x - a.y * b.y, a.x * b.y + a.y * b.x);}
    
    void dft(int lim, complex *a, int type)//type = 1 DFT;type = -1 IDFT
    {
        if (lim == 1) return;//返回常数项
        complex a1[lim >> 1], a2[lim >> 1];
        for (int i = 0; i < lim; i += 2)
            a1[i >> 1] = a[i], a2[i >> 1] = a[i + 1];
        dft(lim >> 1, a1, type);
        dft(lim >> 1, a2, type);
        complex Wn = complex(cos(2.0 * pi / lim), type * sin(2.0 * pi / lim)), w = complex(1, 0);
        for (int i = 0; i < (lim >> 1); ++i, w = w * Wn)
        {
            a[i] = a1[i] + w * a2[i];
            a[i + (lim >> 1)] = a1[i] - w * a2[i];
        }
        return;
    }
    
    int main()
    {
        int n = read(), m = read();
        FOR(i, 0, n) a[i].x = read();
        FOR(i, 0, m) b[i].x = read();
        int lim = 1;
        while (lim <= n + m) lim <<= 1;//lim一定要大于 n + m
        dft(lim, a, 1);
        dft(lim, b, 1);
        FOR(i, 0, lim)
            a[i] = a[i] * b[i];//点值乘起来
        dft(lim, a, -1);//IDFT还回去
        FOR(i, 0, n + m)
            printf("%d ", (int)(a[i].x / lim + 0.5));//最后要除那个数然后还原回去,四舍五入
        return 0;
    }
    

    位逆序置换

    然而,上面的代码连模板都跑不过去……

    考虑继续优化 DFT 的过程。递归的过程中开了大量的空间并且常数巨大,考虑非递归写法。

    只考虑我们对 0077 操作:

    递归的过程:

    original		0	1	2	3	4	5	6	7
    recursion#1		0	2	4	6	1	3	5	7
    recursion#2		0	4	2	6	1	5	3	7
    recursion#3		0	4	2	6	1	5	3	7
    original bin	000	001	010	011	100	101	110	111
    now bin			000	100	010	110	001	101	011	111
    

    可见递归到最后的结果无非就是一个二进制反转。

    所以我们可以考虑非递归,一开始就先把所有的数放到最后的位置,然后迭代的时候一步步还回去即可。这个过程就是位逆序置换(蝴蝶变换)

    考虑处理出 xx 二进制位翻转之后的数 R(x)R(x)。易知 R(0)=0R(0) = 0。我们可以从小到大求 R(x)R(x)。很明显,x/2\lfloor x/2\rfloor 的二进制位是 xx 右移一位,那么如果知道了 R(x/2)R(\lfloor x/2\rfloor) 就可以很容易的求出 R(x)R(x),再分 xx 的奇偶性判断就可以了。
    R(x)=R(x/2)2+(xmod2)×len2 R(x) = \left\lfloor\frac{R(\lfloor x/2\rfloor)}{2}\right\rfloor + (x\bmod 2)\times\frac{len}2
    举个例子:翻转 (10101110)2(10101110)_2,首先我们知道它的二分之一倍为 (01010111)2(01010111)_2,其翻转结果为 (11101010)2(11101010)_2,除以二变为 (01110101)2(01110101)_2,由于它是偶数所以前面不用补 11。不难发现其就是一开始要求的翻转结果。

    预处理翻转结果的代码:

    while (lim <= n + m) lim <<= 1;
    FOR(i, 0, lim - 1)
        rev[i] = ((rev[i >> 1] >> 1) | (((i & 1) ? (lim >> 1) : 0)));
    

    然后在处理翻转的时候只需要下面几行:

    FOR(i, 0, lim - 1)
        if (i < rev[i])
            myswap(a[i], a[rev[i]]);
    

    不难验证其正确性。

    而且观察我们在求 DFT(f(ωnk))\mathrm{DFT}(f(\omega_n^k)) 时我们需要算两遍 ωnkDFT(f1(ωn2k))\omega_n^k\mathrm{DFT}(f_1(\omega_\frac n2^k)),复数的乘法常数很大,考虑使用临时变量记录以降低常数。

    这样子的话迭代版的 DFT 过程就很好写了:

    void DFT(int lim, complex *a, int type)
    {
        FOR(i, 0, lim - 1)
            if (i < rev[i])
                myswap(a[i], a[rev[i]]);//先预处理翻转完了的结果
        for (int p = 2; p <= lim; p <<= 1)//模拟合并答案的过程,即为所谓的 n
        {
            int len = p >> 1;//即上面的 n / 2
            complex Wp = complex(cos(2 * pi / p), type * sin(2 * pi / p));//处理出 p 次单位根
            for (int k = 0; k < lim; k += p)//对每一个进行合并
            {
                complex w = complex(1, 0);//处理 \omega_p^0
                for (int l = k; l < k + len; ++l, w = w * Wp)//开始合并
                {
                    //此时的 a[l] 就是之前的 a1[i],a[len + l] 就是之前的 a2[i]
                    complex tmp = w * a[len + l];
                    a[len + l] = a[l] - tmp;//相当于上面的 a[i + (lim >> 1)] = a1[i] - w * a2[i]
                    a[l] = a[l] + tmp;//相当于上面的 a[i] = a1[i] + w * a2[i]
                }
            }
        }
    }
    

    多项式乘法的实现

    总的一个非递归版 FFT 的实现如下(洛谷 P3803):

    #include <cstdio>
    #include <cctype>
    #include <cmath>
    #define FOR(i, a, b) for (int i = a; i <= b; ++i)
    
    const int maxn = 3e6 + 5;
    const double pi = acos(-1.0);
    
    inline int read()
    {
        char c = getchar();
        int s = 0;
        while (!isdigit(c))
            c = getchar();
        while (isdigit(c))
            s = 10 * s + c - '0', c = getchar();
        return s;
    }
    
    template<typename T> inline void myswap(T &a, T &b)
    {
        T t = a;
        a = b;
        b = t;
        return;
    }
    
    struct complex
    {
        double x, y;
        complex(double xx = 0, double yy = 0)
        {
            x = xx, y = yy;
        }
    } a[maxn], b[maxn];
    
    int rev[maxn];
    
    complex operator+(const complex &a, const complex &b) {return complex(a.x + b.x, a.y + b.y);}
    complex operator-(const complex &a, const complex &b) {return complex(a.x - b.x, a.y - b.y);}
    complex operator*(const complex &a, const complex &b) {return complex(a.x * b.x - a.y * b.y, a.x * b.y + a.y * b.x);}
    
    void DFT(int lim, complex *a, int type)
    {
        FOR(i, 0, lim - 1)
            if (i < rev[i])
                myswap(a[i], a[rev[i]]);//先预处理翻转完了的结果
        for (int p = 2; p <= lim; p <<= 1)//模拟合并答案的过程,即为所谓的 n
        {
            int len = p >> 1;//即上面的 n / 2
            complex Wp = complex(cos(2 * pi / p), type * sin(2 * pi / p));//处理出 p 次单位根
            for (int k = 0; k < lim; k += p)//对每一个进行合并
            {
                complex w = complex(1, 0);//处理 \omega_p^0
                for (int l = k; l < k + len; ++l, w = w * Wp)//开始合并
                {
                    //此时的 a[l] 就是之前的 a1[i],a[len + l] 就是之前的 a2[i]
                    complex tmp = w * a[len + l];
                    a[len + l] = a[l] - tmp;//相当于上面的 a[i + (lim >> 1)] = a1[i] - w * a2[i]
                    a[l] = a[l] + tmp;//相当于上面的 a[i] = a1[i] + w * a2[i]
                }
            }
        }
    }
    
    int main()
    {
        int n = read(), m = read();
        FOR(i, 0, n) a[i].x = read();
        FOR(i, 0, m) b[i].x = read();
        int lim = 1;
        while (lim <= n + m) lim <<= 1;//补齐高位
        FOR(i, 0, lim - 1)
            rev[i] = ((rev[i >> 1] >> 1) | (((i & 1) ? (lim >> 1) : 0)));//先处理翻转完的结果
        DFT(lim, a, 1);//DFT
        DFT(lim, b, 1);//DFT
        FOR(i, 0, lim)
            a[i] = a[i] * b[i];//对处理出来的点值进行乘法
        DFT(lim, a, -1);//IDFT
        FOR(i, 0, n + m)
            printf("%d ", (int)(a[i].x / lim + 0.5));
        return 0;
    }
    

    使用 FFT 来求高精度整数乘法的实现(洛谷 P1919):

    #include <cstdio>
    #include <cstring>
    #include <cmath>
    #define FOR(i, a, b) for (int i = a; i <= b; ++i)
    #define DEC(i, a, b) for (int i = a; i >= b; --i)
    
    template<typename T> inline void myswap(T &a, T &b) {T t = a; a = b; b = t; return;}
    
    typedef double db;
    
    const int maxn = 3000000 + 5;
    const db pi = acos(-1.0);
    
    struct cmplx
    {
        db x, y;
        cmplx(db xx = 0, db yy = 0) {x = xx, y = yy;}
    } a[maxn], b[maxn];
    
    cmplx operator+(const cmplx &a, const cmplx &b) {return cmplx(a.x + b.x, a.y + b.y);}
    cmplx operator-(const cmplx &a, const cmplx &b) {return cmplx(a.x - b.x, a.y - b.y);}
    cmplx operator*(const cmplx &a, const cmplx &b) {return cmplx(a.x * b.x - a.y * b.y, a.x * b.y + a.y * b.x);}
    
    char s1[maxn], s2[maxn];
    int rev[maxn], ans[maxn];
    
    void DFT(cmplx *f, int lim, int type)
    {
        FOR(i, 0, lim - 1)
            if (i < rev[i])
                myswap(f[i], f[rev[i]]);
        for (int p = 2; p <= lim; p <<= 1)
        {
            int len = p >> 1;
            cmplx Wp(cos(2.0 * pi / p), type * sin(2.0 * pi / p));
            for (int k = 0; k < lim; k += p)
            {
                cmplx w(1, 0);
                for (int l = k; l < k + len; ++l, w = w * Wp)
                {
                    cmplx tmp = w * f[l + len];
                    f[l + len] = f[l] - tmp;
                    f[l] = f[l] + tmp;
                }
            }
        }
        return;
    }
    
    int main()
    {
        scanf("%s\n%s", s1, s2);
        int n1 = -1, n2 = -1;
        DEC(i, strlen(s1) - 1, 0) a[++n1].x = s1[i] - '0';
        DEC(i, strlen(s2) - 1, 0) b[++n2].x = s2[i] - '0';
        int lim = 1;
        while (lim <= n1 + n2) lim <<= 1;
        FOR(i, 0, lim - 1)
            rev[i] = ((rev[i >> 1] >> 1) | (((i & 1) ? (lim >> 1) : 0)));
        DFT(a, lim, 1);
        DFT(b, lim, 1);
        FOR(i, 0, lim)
            a[i] = a[i] * b[i];
        DFT(a, lim, -1);
        FOR(i, 0, lim)
            ans[i] = (int)(a[i].x / lim + 0.5);
        FOR(i, 0, lim)
            if (ans[i] >= 10) ans[i + 1] += ans[i] / 10, ans[i] %= 10, lim += (i == lim);
        while (!ans[lim] && lim > -1) --lim;
        if (lim == -1) puts("0");
        else DEC(i, lim, 0) printf("%d", ans[i]);
        return 0;
    }
    

    当然,千万要记得 IDFT 还回去的时候要除以 lim,实在怕记不住就在 DFT 函数里面加几句话直接处理好

    if (type == -1)
        FOR(i, 0, lim - 1)
            f[i].x /= lim;
    

    针对多项式乘法:三次变两次优化

    我们发现我们在做多项式乘法的时候,需要先 DFT A(x)A(x)B(x)B(x),乘在一起之后再 IDFT 还回来 C(x)C(x),一共进行了三次这样的操作。考虑如何减少我们调用 DFT 的次数。

    可以把 B(x)B(x) 的系数放到 A(x)A(x) 系数的虚部上面,即 a+bia + b\mathrm i,然后 DFT 一下 A(x)A(x) 再求个平方,得到 A2(x)A^2(x),再 IDFT 回去。我们可以发现得到的系数都是 (a+bi)2=a2b2+2abi(a + b\mathrm i)^2 = a^2 - b^2 + 2ab\mathrm i 的形式的,所以只需要取出虚部再除以二就得到答案了。

    这样的写法可以减小常数,跑的比 NTT 还快。

    #include <cstdio>
    #include <cctype>
    #include <cmath>
    #define FOR(i, a, b) for (int i = a; i <= b; ++i)
    
    typedef double db;
    
    const int maxn = 3e6 + 5;
    const db pi = acos(-1.0);
    
    inline int read()
    {
        char c = getchar();
        int s = 0;
        while (!isdigit(c))
            c = getchar();
        while (isdigit(c))
            s = 10 * s + c - '0', c = getchar();
        return s;
    }
    
    template<typename T> inline void myswap(T &a, T &b)
    {
        T t = a;
        a = b;
        b = t;
        return;
    }
    
    struct cmplx
    {
        db x, y;
        cmplx(db xx = 0, db yy = 0)
        {
            x = xx, y = yy;
        }
    } a[maxn];
    
    int rev[maxn];
    
    cmplx operator+(const cmplx &a, const cmplx &b) {return cmplx(a.x + b.x, a.y + b.y);}
    cmplx operator-(const cmplx &a, const cmplx &b) {return cmplx(a.x - b.x, a.y - b.y);}
    cmplx operator*(const cmplx &a, const cmplx &b) {return cmplx(a.x * b.x - a.y * b.y, a.x * b.y + a.y * b.x);}
    
    void DFT(cmplx *f, int lim, int type)
    {
        FOR(i, 0, lim - 1)
            if (i < rev[i])
                myswap(f[i], f[rev[i]]);
        for (int p = 2; p <= lim; p <<= 1)
        {
            int len = p >> 1;
            cmplx Wp(cos(2 * pi / p), type * sin(2 * pi / p));
            for (int k = 0; k < lim; k += p)
            {
                cmplx w(1, 0);
                for (int l = k; l < k + len; ++l, w = w * Wp)
                {
                    cmplx tmp = w * f[len + l];
                    f[len + l] = f[l] - tmp;
                    f[l] = f[l] + tmp;
                }
            }
        }
    }
    
    int main()
    {
        int n = read(), m = read();
        FOR(i, 0, n) a[i].x = read();
        FOR(i, 0, m) a[i].y = read();
        int lim = 1;
        while (lim <= n + m) lim <<= 1;
        FOR(i, 0, lim - 1)
            rev[i] = ((rev[i >> 1] >> 1) | ((i & 1) ? (lim >> 1) : 0));
        DFT(a, lim, 1);
        FOR(i, 0, lim - 1)
            a[i] = a[i] * a[i];
        DFT(a, lim, -1);
        FOR(i, 0, n + m)
            printf("%d ", (int)(a[i].y / lim / 2.0 + 0.5));
        return 0;
    }
    

    快速数论变换(NTT)

    有了 FFT,我们已经有能力在 O(nlogn)O(n\log n) 的时间内求出两个多项式的卷积了。但是 FFT 也有它的缺点:复数采用的浮点运算不仅造成精度的问题,还会增大常数。遗憾的是数学家们已经证明了 C\mathbb C 中只有单位复数根满足 FFT 的要求。

    考虑到利用多项式的计数题很多都是模意义下的,所以自然希望为单位复数根找一个模意义下的替代品。此时就进入下面的前置知识:原根。

    原根

    设整数 r,nr,n 满足 rnr0n>0r\perp n\land r \not= 0 \land n > 0,使得 rx1(modn)r^x \equiv 1\pmod n最小正整数 xx 称为 rrnn,记为 ordnr\mathrm{ord}_nrδn(r)\delta_n(r)

    r,nN+rnr,n\in\mathbb N^+\land r\perp n,当 ordnr=ϕ(n)\operatorname{ord}_nr = \phi(n) 时,称 rr 是模 nn 的原根或者 nn 的原根。

    NTT

    对于质数 p=qn+1(n=2m)p = qn + 1\:(n = 2^m),原根 gg 满足 gqn1(modp)g^{qn}\equiv 1\pmod p,将 gn=gq(modp)g_n = g^q\pmod p 看作 ωn\omega_n 的等价,其满足相似的性质,比如 gnn1(modp),gnn/21(modp)g_n^n\equiv 1\pmod p,g_n^{n/2} \equiv -1\pmod p

    常见的质数
    p=998244353=7×17×223+1,g=3p=1004535809=479×221+1,g=3 \begin{aligned} p &= 998244353 = 7\times17\times2^{23} + 1,&g = 3\\ p &= 1004535809 = 479\times 2^{21} + 1,&g = 3 \end{aligned}
    迭代到长度为 ll 时,gl=gp1lg_l = g^{\frac{p - 1}{l}}

    直接看代码:

    #include <cstdio>
    #include <cctype>
    #define FOR(i, a, b) for (int i = a; i <= b; ++i)
    
    typedef long long ll;
    
    const ll G = 3;
    const ll mod = 998244353;
    const int maxn = 3e6 + 5;
    
    inline int read()
    {
        char c = getchar();
        int s = 0;
        while (!isdigit(c))
            c = getchar();
        while (isdigit(c))
            s = 10 * s + c - '0', c = getchar();
        return s;
    }
    
    template<typename T> inline void myswap(T &a, T &b)
    {
        T t = a;
        a = b;
        b = t;
        return;
    }
    
    ll pow(ll base, ll p = mod - 2)
    {
        ll ret = 1;
        for (; p; p >>= 1)
        {
            if (p & 1)
                ret = ret * base % mod;
            base = base * base % mod;
        }
        return ret;
    }
    
    int rev[maxn];
    ll f[maxn], g[maxn];
    const ll invG = pow(G);
    
    void NTT(ll *f, int lim, int type)
    {
        FOR(i, 0, lim - 1)
            if (i < rev[i])
                myswap(f[i], f[rev[i]]);
        for (int p = 2; p <= lim; p <<= 1)
        {
            int len = p >> 1;
            ll tG = pow(type ? G : invG, (mod - 1) / p);
            for (int k = 0; k < lim; k += p)
            {
                ll buf = 1;
                for (int l = k; l < k + len; ++l, buf = buf * tG % mod)
                {
                    ll tmp = buf * f[len + l] % mod;
                    f[len + l] = f[l] - tmp;
                    if (f[len + l] < 0) f[len + l] += mod;//及时取模
                    f[l] = f[l] + tmp;
                    if (f[l] > mod) f[l] -= mod;//及时取模
                }
            }
        }
        ll invlim = pow(lim);//最后还回去,除以lim相当于乘上lim的逆元
        if (!type)
            FOR(i, 0, lim - 1)
                f[i] = (f[i] * invlim % mod);
        return;
    }
    
    int main()
    {
        int n = read(), m = read();
        FOR(i, 0, n) f[i] = read();
        FOR(i, 0, m) g[i] = read();
        int lim = 1;
        while (lim <= n + m) lim <<= 1;
        FOR(i, 0, lim - 1)
            rev[i] = (rev[i >> 1] >> 1) | ((i & 1) ? (lim >> 1) : 0);
        NTT(f, lim, 1), NTT(g, lim, 1);
        FOR(i, 0, lim - 1)
            f[i] = f[i] * g[i] % mod;
        NTT(f, lim, 0);
        FOR(i, 0, n + m)
            printf("%d ", (int)f[i]);
        return 0;
    }
    

    FFT/NTT 优化卷积的一些例子

    在继续之前,我们先来看看 FFT/NTT 的一些应用。(高精度乘法就不说了,记得最后进位就可以了)

    • 优化一般的卷积
    • 和生成函数一起食用
    • 字符串匹配(你没看错)

    洛谷 P3338 [ZJOI2014]力

    题意:给定 {q}\{q\},定义
    Fi=j=1i1qiqj(ij)2j=i+1nqiqj(ij)2 F_i = \sum_{j = 1}^{i - 1}\frac{q_iq_j}{(i - j)^2} - \sum_{j = i + 1}^n\frac{q_iq_j}{(i - j)^2}

    Ei=Fiqi E_i=\frac{F_i}{q_i}
    考虑暴力的话,这道题是 O(n2)O(n^2) 的,过不去,考虑转化式子:
    Ei=Fiqi=j=1i1qj(ij)2j=i+1nqj(ij)2 \begin{aligned} E_i &= \frac{F_i}{q_i}\\ &=\sum_{j = 1}^{i - 1}\frac{q_j}{(i - j)^2} - \sum_{j = i + 1}^n\frac{q_j}{(i - j)^2}\\ \end{aligned}
    我们尝试将其化为卷积的形式,令 fi=qif_i = q_i,且 f0=0f_0 = 0gi=1i2g_i =\dfrac{1}{i^2},且 g0=0g_0 = 0,回代:
    Ei=j=0ifjgijj=infjgji \begin{aligned} E_i &= \sum_{j = 0}^{i}f_jg_{i - j} - \sum_{j = i}^nf_jg_{j - i} \end{aligned}
    左边的部分已经是一个卷积的形式了,考虑继续化简右边。此时我们可以使用一个翻转的技巧,令 fi=fnif'_i = f_{n - i}t=nit = n - i,则右半边的式子可以化为 j=0tftjgj\displaystyle\sum_{j = 0}^{t}f'_{t - j}g_j。现在两边都化为卷积的形式了,可以愉快的使用 FFT 加速了。

    即我们设多项式 A(x)=i=0nfixnA(x) =\displaystyle\sum_{i = 0}^nf_ix^nB(x)=i=0ngixnB(x) = \displaystyle\sum_{i = 0}^ng_ix^nC(x)=i=0nfiC(x) = \displaystyle\sum_{i = 0}^nf'_i。再令 L(x)=A(x)B(x)L(x) = A(x)B(x)R(x)=B(x)C(x)R(x) = B(x)C(x),不难发现答案 Ei=lirniE_i = l_i - r_{n - i},其中 lil_irir_i 分别为 L(x)L(x)R(x)R(x)xix^i 的系数。

    int main()
    {
        int n; scanf("%d", &n);
        FOR(i, 1, n)
        {
            scanf("%lf", &a[i].x);
            b[i].x = (1.0 / i / i);
            c[n - i].x = a[i].x;
        }
        int lim = 1;
        while (lim <= (n << 1)) lim <<= 1;
        FOR(i, 0, lim)
            rev[i] = ((rev[i >> 1] >> 1) | ((i & 1) ? (lim >> 1) : 0));
        DFT(a, lim, 1), DFT(b, lim, 1), DFT(c, lim, 1);
        FOR(i, 0, lim)
            a[i] = a[i] * b[i], c[i] = b[i] * c[i];
        DFT(a, lim, -1), DFT(c, lim, -1);
        FOR(i, 1, n)
            printf("%.3lf\n", a[i].x - c[n - i].x);
        return 0;
    }
    

    洛谷 P3723 [AH2017/HNOI2017]礼物

    题意:给定两个序列 {x}\{x\}{y}\{y\},可以整体平移序列或者整体加/减某个数,求最终序列

    i=1n(xiyi)2 \sum_{i = 1}^n(x_i - y_i)^2

    的最小值。

    分析:设整体加减的数为 cccc 可正可负),我们需要最小化的就是下面这个式子:

    i=1n(xiyi+c)2 \sum_{i = 1}^n(x_i - y_i + c)^2

    展开上面的式子,由 (xiyi+c)2=xi2+yi2+c2+2xic2yic2xiyi(x_i - y_i +c)^2 = x_i^2 + y_i^2 + c^2 + 2x_ic - 2y_ic - 2x_iy_i 可以得到原式可化简为

    xi2+yi2+nc2+2cxi2cyi2xiyi \sum x_i^2 + \sum y_i^2 + nc^2 + 2c\sum x_i - 2c\sum y_i - 2\sum x_iy_i

    (下标省略)

    不难发现我们只需要最大化 xiyi\sum x_iy_i