memo46

競プロの精進記録その他。

CODE THANKS FESTIVAL 2017(Parallel) F - Limited Xor Subset

別解が面白かったのでメモを残します。

問題

N要素の正の整数列aがあり、i番目の要素はa_{i}である。 N個の整数のうち0個以上を選んでそれらのbitごとの総XORを計算したとき、それがKとなるような整数の選び方の総数を10^{9}+7で割った余りを求めなさい。 ただし、0個選んだときのbitごとのXORは0とする。

制約

  • 1<=N<=100000
  • 0<=K<=100000
  • 1<=a_{i} (1<=i<=N)
  • \sum_{1}^{N} {a_i} <= 100000
  • 入力はすべて整数

解法

  • 想定解(https://img.atcoder.jp/code-thanks-festival-2017-open/editorial.pdf)はdpですが、ここでは別解のみを書きます。
  • まず、問題文を読むと、この問題の性質より、①数列の要素をswapしても答えは変わらない。②数列の2つの要素a_ia_jに対してXORを取り、それを新たな数列の要素(つまり、a_i\oplus{a_j}a_jに新たに変更する)としても答えは変わらない。ことがわかります。まず、①について、XORは可換であるためこれは成立します。②については、以下の説明により成り立ちます。 f:id:bakamono1357:20191118145958p:plain
  • ここで、数列aKをそれぞれmod2上の行列、行ベクトルで表すことを考えます。
  • まず、数列をbitごとに分解します。各桁を行、数列の個数を列とし、行列にして表します。例えば、N=4a={1,2,5,7}の時、以下のように表します。この場合、i行が下からi桁目のbit、j列がj番目の数列の要素に対応していることがわかります。 f:id:bakamono1357:20191118143759p:plain
  • 次に、Kを各bitに分解し、行ベクトルに分解します。例えばK=3の場合、以下のように表します。 f:id:bakamono1357:20191118144201p:plain
  • ここで、実は、先程説明した①は行の入れ替え、②はある行に別の行を加えるに対応していることがわかります。つまり、これは行基本変形と同じであることがわかります。
  • ここから考えると、求める選び方の総数は連立1次方程式Ax=bの解の個数に等しいです。よって答えは、行基本変形により掃き出し法でrankを求め、解が存在するならば2^{N-rank}(なぜなら、パラメータの個数はN-rank個あり、各パラメータの値はmod2なので2通りしかないため)存在しないならば0通りであることがわかりました。計算量は行数をR、列数をCとすると、bitset高速化により、O(R^{2}C/64)となります。
  • 参考文献:けんちょんさんの記事(http://drken1215.hatenablog.com/entry/2019/03/20/202800)、trapのブログ(https://trap.jp/post/435/

実装

#include <bits/stdc++.h>
using namespace std;
template <class T>
inline bool chmax(T &a, T b)
{
    if (a < b)
    {
        a = b;
        return 1;
    }
    return 0;
}
template <class T>
inline bool chmin(T &a, T b)
{
    if (a > b)
    {
        a = b;
        return 1;
    }
    return 0;
}
typedef long long int ll;

#define ALL(v) (v).begin(), (v).end()
#define RALL(v) (v).rbegin(), (v).rend()
#define endl "\n"
const double EPS = 1e-7;
const int INF = 1 << 30;
const ll LLINF = 1LL << 60;
const double PI = acos(-1);
const int MOD = 1000000007;
const int dx[4] = {1, 0, -1, 0};
const int dy[4] = {0, 1, 0, -1};

//-------------------------------------

const int MAX_ROW = 20;     // 行
const int MAX_COL = 110000; // 列

struct BitMatrix
{
    int h, w;
    bitset<MAX_COL> val[MAX_ROW];
    BitMatrix(int m = 1, int n = 1) : h(m), w(n) {}
    bitset<MAX_COL> &operator[](int i) { return val[i]; }
};

// 掃き出し法によりrankを求める
int GaussJordan(BitMatrix &A, bool isExtended = false)
{
    int rank = 0;
    for (int c = 0; c < A.w; c++)
    {
        if (isExtended && c == A.w - 1)
        {
            break;
        }
        int p = -1;
        for (int r = rank; r < A.h; r++)
        {
            if (A[r][c])
            {
                p = r;
                break;
            }
        }
        if (p == -1)
        {
            continue;
        }
        swap(A[p], A[rank]);
        for (int r = 0; r < A.h; r++)
        {
            if (r != rank && A[r][c])
            {
                A[r] ^= A[rank];
            }
        }
        rank++;
    }
    return rank;
}

int linearEquation(BitMatrix A, vector<int> b, vector<int> &res)
{
    int m = A.h;
    int n = A.w;
    BitMatrix mat(m, n + 1);
    for (int i = 0; i < m; i++)
    {
        for (int j = 0; j < n; j++)
        {
            mat[i][j] = A[i][j];
        }
        mat[i][n] = b[i];
    }
    int rank = GaussJordan(mat, true);
    // 解なしの場合
    for (int r = rank; r < m; r++)
    {
        if (mat[r][n])
        {
            return -1;
        }
    }
    // 解がある場合、resに解を代入し、rankを返す
    res.assign(n, 0);
    for (int i = 0; i < rank; i++)
    {
        res[i] = mat[i][n];
    }
    return rank;
}

ll pow_mod(ll n, ll k, ll mod)
{
    if (k == 0)
    {
        return 1;
    }
    else if (k % 2 == 1)
    {
        return pow_mod(n, k - 1, mod) * n % mod;
    }
    else
    {
        ll t = pow_mod(n, k / 2, mod);
        return t * t % mod;
    }
}

int main()
{
    cin.tie(0);
    ios::sync_with_stdio(false);
    int n, k;
    cin >> n >> k;
    vector<int> a(n);
    for (int i = 0; i < n; i++)
    {
        cin >> a[i];
    }
    BitMatrix A(20, n);
    vector<int> b(20);
    for (int d = 0; d < 20; d++)
    {
        for (int i = 0; i < n; i++)
        {
            if (a[i] & (1 << d))
            {
                A[d][i] = 1;
            }
        }
        if (k & (1 << d))
        {
            b[d] = 1;
        }
    }
    vector<int> ans;
    int rank = linearEquation(A, b, ans);
    cout << (rank == -1 ? 0 : pow_mod(2LL, n - rank, MOD)) << endl;
}

感想

  • 俗に言う「F2上で線形代数」ってやつだそうです。実装はけんちょんさんの記事を大いに参考(というか写経)にしました。
  • この解法ならばa_iの制約がもっと大きくても解けるので、すごい...