diff --git "a/4-\346\225\260\345\255\246.md" "b/4-\346\225\260\345\255\246.md" index a5daefa..d2589d7 100644 --- "a/4-\346\225\260\345\255\246.md" +++ "b/4-\346\225\260\345\255\246.md" @@ -190,61 +190,143 @@ int cat(int n) ## 矩阵 ```cpp -class matrix -{ -public: - int x[105][105]; - int sz; - matrix(int n) - { - sz = n; - for (int i = 1; i <= sz; i++) - { - for (int j = 1; j <= sz; j++) - { - x[i][j] = 0; +struct matrix{ + ll n; + vector> M; + matrix(ll nn){ + n=nn; + M.resize(n+1,vector(n+1)); + clear(); + } + void clear(){ + rep(i,0,n){ + rep(j,0,n){ + M[i][j]=0; } } } - matrix mul(matrix a, matrix b); - matrix qpow(matrix a, int n); - void tra(matrix a); -}; - -matrix matrix::mul(matrix a, matrix b) -{ - matrix c(a.sz); - for (int i = 1; i <= a.sz; i++) - for (int j = 1; j <= a.sz; j++) - for (int k = 1; k <= a.sz; k++) - c.x[i][j] = (c.x[i][j] % mod + (a.x[i][k] * b.x[k][j]) % mod) % mod; - return c; -} -matrix matrix::qpow(matrix a, int n) -{ - matrix res(a.sz); - for (int i = 1; i <= a.sz; i++) - res.x[i][i] = 1; - while (n > 0) - { - if (n & 1) - res = mul(res, a); - a = mul(a, a); - n >>= 1; + void reset(){ + clear(); + rep(i,0,n){ + M[i][i]=1; + } } - return res; -} -void matrix::tra(matrix a) -{ - for (int i = 1; i <= a.sz; i++) - { - for (int j = 1; j <= a.sz; j++) - { - cout << a.x[i][j] << " "; + matrix operator+(matrix t){ + matrix ans(n); + rep(i,1,n){ + rep(j,1,n){ + ans.M[i][j]=(M[i][j]+t.M[i][j])%mod; + } } - cout << endl; + return ans; } -} + matrix operator-(matrix t){ + matrix ans(n); + rep(i,1,n){ + rep(j,1,n){ + ans.M[i][j]=(M[i][j]-t.M[i][j]+mod)%mod; + } + } + return ans; + } + matrix operator*(matrix t){ + matrix ans(n); + rep(i,1,n){ + rep(j,1,n){ + rep(k,1,n){ + ans.M[i][j]+=M[i][k]*t.M[k][j]; + ans.M[i][j]%=mod; + } + } + } + return ans; + } + matrix inv(bool &ret){//只做初等行变换是不会影响结果的 + ll m=n*2; + matrix a(m); + rep(i,1,n){ + rep(j,1,n){ + a.M[i][j]=M[i][j]; + } + a.M[i][i+n]=1; + } + matrix ans(n); + rep(i,1,n){ + ll pos=i; + rep(j,i+1,n){ + if(llabs(a.M[j][i])>llabs(a.M[pos][i])){ + pos=j; + } + } + //找最大防止这里是0 + if(i!=pos){ + swap(a.M[i],a.M[pos]); + } + if(a.M[i][i]==0){ + ret=false; + return ans; + } + ll inv=fast(a.M[i][i],mod-2); + rep(j,i,m){ + a.M[i][j]=a.M[i][j]*inv%mod; + } + rep(j,1,n){ + if(j==i) continue; + ll b=a.M[j][i]; + rep(k,i,m){ + a.M[j][k]=(a.M[j][k]-b*a.M[i][k]%mod+mod)%mod; + } + } + } + rep(i,1,n){ + rep(j,1,n){ + ans.M[i][j]=a.M[i][j+n]; + } + } + return ans; + } + void print(){ + rep(i,1,n){ + rep(j,1,n){ + cout<