#include <algorithm>
#include <cctype>
#include <cstdio>
#include <utility>
#include <vector>
using namespace std;
typedef long long ll;
#define pb push_back
// Bring a value from [P, 2P) (Mod1) or [-P, 0) (Mod2) back into [0, P).
#define Mod1(x) (((x)>=P)&&((x)-=P))
#define Mod2(x) (((x)<0)&&((x)+=P))
// Inclusive ascending / descending loops; the bound b is evaluated once.
#define rep(i,a,b) for(int i=(a),i##end=(b);i<=i##end;++i)
#define drep(i,a,b) for(int i=(a),i##end=(b);i>=i##end;--i)
// The character read right after the last number parsed by rd() (EOF cast to char at end of input).
char IO;
// Reads the next non-negative decimal integer from stdin, skipping any
// non-digit characters before it (a leading '-' is skipped as well, so negative
// values are not supported). Returns 0 if EOF is reached before any digit,
// instead of spinning forever.
int rd(){
    int c = getchar();  // int, not char: keeps EOF distinct and isdigit's argument valid
    while (c != EOF && !isdigit(c)) c = getchar();
    int s = 0;
    while (c != EOF && isdigit(c)) {
        s = s * 10 + (c - '0');
        c = getchar();
    }
    IO = static_cast<char>(c);
    return s;
}
// P: NTT-friendly prime, P = 119 * 2^23 + 1 (the transforms below use 3 as root).
constexpr int P = 998244353;
// N: capacity of the global lookup arrays (factorials, bit-reversal, input): 2^18 plus 10 slack.
constexpr int N = (1 << 18) | 10;
// n and k are read from the input in main().
int n,k;
// Returns x^k mod P by binary exponentiation (k must be non-negative).
// With the default exponent P-2 this is the modular inverse of x (Fermat),
// valid when x is not a multiple of P.
ll qpow(ll x,ll k=P-2) {
    // Reduce first: an unreduced x would overflow in x*x and a negative x
    // would produce negative residues.
    x %= P;
    if (x < 0) x += P;
    ll res = 1;
    for (; k; k >>= 1, x = x * x % P)
        if (k & 1) res = res * x % P;
    return res;
}
// Global tables (sized by N):
//   J[i] = i! mod P and I[i] = (i!)^{-1} mod P for 0 <= i <= N-1, filled by Init();
//   rev[] = bit-reversal permutation for the current transform length, filled by Init(int).
int rev[N],I[N],J[N];
// V: a polynomial as its coefficient list, lowest degree first. Coefficients are
// assumed to lie in [0, P): the Mod1/Mod2 reductions only correct by a single P.
typedef vector <int> V;
// Builds the factorial table, then the inverse factorials: one modular inverse
// of (N-1)! followed by I[i-1] = I[i] * i while walking downward.
void Init(){
  rep(i,J[0]=1,N-1) J[i]=1ll*J[i-1]*i%P;
  I[N-1]=qpow(J[N-1]);
  drep(i,N-1,1) I[i-1]=1ll*I[i]*i%P;
}
// Chooses the transform length R = smallest power of two strictly greater than n,
// fills rev[0..R-1] with the bit-reversal permutation for R, and returns R.
// c ends as log2(R)-1: the bit position that the lowest bit of i is moved to.
int Init(int n){
  int R=1,c=-1;
  while(R<=n) R<<=1,c++;
  rep(i,0,R-1) rev[i]=(rev[i>>1]>>1)|((i&1)<<c);
  return R;
}
// In-place iterative number-theoretic transform of a[0..n-1] modulo P.
// Preconditions: n is a power of two, rev[] was prepared for n by Init(n),
// and a has at least n elements.
// f == 1 uses root 3 (forward); any other f uses 3^{-1} = (P+1)/3, and only
// f == -1 additionally scales every entry by 1/n (inverse transform).
void NTT(int n,V &a,int f) {
  static int e[N>>1];  // twiddle factors for the current butterfly span
  rep(i,0,n-1) if(i<rev[i]) swap(a[i],a[rev[i]]);  // bit-reversal reordering
  for(int i=e[0]=1;i<n;i<<=1) {
    // t: primitive (2i)-th root of unity mod P (or its inverse when f != 1).
    ll t=qpow(f==1?3:(P+1)/3,(P-1)/i/2);
    // Expand e[0..i/2-1] (powers of t^2 from the previous span) in place into
    // e[j] = t^j for 0 <= j < i; walking j downward reads each e[j>>1]
    // before its slot is overwritten.
    for(int j=i-2;j>=0;j-=2) e[j+1]=(e[j]=e[j>>1])*t%P;
    for(int l=0;l<n;l+=i*2) {
      for(int j=l;j<l+i;++j) {
        // Butterfly: (a[j], a[j+i]) -> (a[j] + w*a[j+i], a[j] - w*a[j+i]), w = e[j-l].
        int t=1ll*e[j-l]*a[j+i]%P;
        a[j+i]=a[j]-t,Mod2(a[j+i]);
        a[j]+=t,Mod1(a[j]);
      }
    }
  }
  if(f==-1) {
    // 1/n = (n-1)! / n!, read from the factorial tables (requires n <= N-1).
    ll Inv=1ll*I[n]*J[n-1]%P;
    rep(i,0,n-1) a[i]=a[i]*Inv%P;
  }
}
// Polynomial product modulo P via NTT. An empty operand yields the empty
// polynomial; otherwise the result has a.size() + b.size() - 1 coefficients.
V operator * (V a,V b){
    if (a.empty() || b.empty()) return V();
    const int len = static_cast<int>(a.size() + b.size()) - 1;
    const int sz = Init(len);  // transform length; also rebuilds rev[] for it
    a.resize(sz);
    b.resize(sz);
    NTT(sz, a, 1);
    NTT(sz, b, 1);
    for (int idx = 0; idx < sz; ++idx)
        a[idx] = static_cast<int>(1ll * a[idx] * b[idx] % P);
    NTT(sz, a, -1);
    a.resize(len);
    return a;
}
// Coefficient-wise sum modulo P; the result is as long as the longer operand.
V operator + (V a,V b){
    if (b.size() > a.size()) a.swap(b);  // accumulate into the longer one
    for (size_t idx = 0; idx < b.size(); ++idx) {
        a[idx] += b[idx];
        if (a[idx] >= P) a[idx] -= P;
    }
    return a;
}
// Coefficient-wise difference a - b modulo P; a is zero-padded to b's length if shorter.
V operator - (V a,const V &b){
    if (b.size() > a.size()) a.resize(b.size());
    for (size_t idx = 0; idx < b.size(); ++idx) {
        a[idx] -= b[idx];
        if (a[idx] < 0) a[idx] += P;
    }
    return a;
}
// Shifts the coefficients up by x positions, i.e. multiplies the polynomial
// by X^x, by prepending x zero coefficients (x is expected to be non-negative).
std::vector<int> operator << (std::vector<int> a,const int &x){
    a.insert(a.begin(), static_cast<std::size_t>(x), 0);
    return a;
}
// Binomial coefficient C(n, m) mod P, or 0 when n < 0, m < 0 or m > n.
// Reads the factorial tables, so n must be below N.
int C(int n,int m){
    if (n < 0 || m < 0 || n < m) return 0;
    return 1ll * J[n] * I[m] % P * I[n - m] % P;
}
// Coefficients of (x - 1)^n: entry i is (-1)^(n-i) * C(n, i) mod P.
V Binom(int n){
    V coef(n + 1);
    for (int i = 0; i <= n; ++i) {
        const int c = C(n, i);
        coef[i] = ((n - i) & 1) ? P - c : c;
    }
    return coef;
}
// Divide and conquer over j in [l, r]; returns
//   sum_{j=l..r} C(n - j*k, j) * (x^k * (x - 1))^(j - l).
V Solve(int n,int l,int r){
    if (l == r) return V{C(n - l * k, l)};
    const int mid = (l + r) >> 1;
    const int span = mid - l + 1;  // number of terms in the left half
    V left = Solve(n, l, mid);
    V right = Solve(n, mid + 1, r) * Binom(span);  // times (x - 1)^span
    return left + (right << (k * span));           // times x^(k*span)
}
// G_n(x) = sum_{j=0..floor(n/(k+1))} C(n - j*k, j) * (x^k * (x - 1))^j.
V GetG(int n){
    const int last = n / (k + 1);
    return Solve(n, 0, last);
}
// F_n(x) = G_n(x) - x^k * G_{n-k}(x); main() only calls this with n >= k.
V GetF(int n){
    V whole = GetG(n);
    V shifted = GetG(n - k) << k;
    return whole - shifted;
}
// Polynomials to be multiplied together; main() pushes one per qualifying run.
vector<V> T;
// Product T[l] * T[l+1] * ... * T[r], using balanced divide and conquer so the
// NTT multiplications work on operands of similar size.
// An empty range (l > r, e.g. when T is empty) returns the empty product {1}.
// Without this guard the recursion never terminates, because mid == -1 when l=0 and r=-1.
// NOTE(review): with {1}, main() prints 0 when no run reaches length k; confirm
// that is the intended output for that case.
V Solve(int l=0,int r=static_cast<int>(T.size())-1){
    if (l > r) return V{1};
    if (l == r) return T[l];
    const int mid = (l + r) >> 1;
    return Solve(l, mid) * Solve(mid + 1, r);
}
// Input values, 1-indexed (A[1..n]).
int A[N];
int main(){ Init(),n=rd(),k=rd(); rep(i,1,n) A[i]=rd(); sort(A+1,A+n+1); rep(i,1,n) { int j=i; while(A[j+1]==A[j]+1) j++; if(j-i+1>=k) T.pb(GetF(j-i+1)); i=j; } V Res=Solve(); int s=0,ans=0; rep(i,1,Res.size()-1){ s=(s+1ll*n*I[i]%P*J[i-1])%P; ans=(ans+1ll*s*Res[i])%P; } ans=(P-ans)%P; printf("%d\n",ans); }
|