[Contest][AOJ][C++] 高速ゼータ変換 / 高速メビウス変換

The Enemy of My Enemy is My Friend | Aizu Online Judgeを高速メビウス変換を使って解きました。その際に学んだことの備忘録としてこの記事を残します。

問題概要

N(N <= 40)個の国がある。国はそれぞれ軍事力B(B <= 1000)を持っている。
以下の条件を満たすように同盟を組むとき、1つめの国を含む同盟の軍事力の和の最大値を求めよ。

  • 自国の隣国とは同盟を結ぶことは出来ない。
  • 同盟をした国の隣国とは同盟を結ぶことは出来ない。

解法

半分全列挙 + 高速メビウス変換(厳密には違う?)

1. 国を20個ずつの2つの集合に分け、それぞれビットマスクmaskで表現する
(maskはnビット目が1→n番目の国を含む、0→n番目の国を含まないに対応)

2. 前半の集合(1つ目の国を含む集合)、後半の集合(1つ目の国を含まない集合)それぞれについて
dp[mask] = (maskに含まれる国で同盟を結ぶことができる→軍事力の和、できない→0)
を計算する

3. 高速メビウス変換で後半の集合について、
dp2[mask] = max_{(s \subset mask)}dp[s]
を計算する

4. 前半の集合の一つ一つにおいて、
dp[mask] + dp2[(maskに含まれる国に隣合わない後半の国の集合のビットマスク)]を計算する

4の最大値が答えになる。
(ただしこれは想定解ではなく、本来は単純に枝刈りで解けるらしい。。)

高速メビウス変換を使っているところ

全ての部分集合sに対して
dp[s] = max_{s \subset u} dp[u]
を高速に求めている(O(N/2 × 2^(N/2)))。

  // 高速メビウス変換でビットマスクsの部分集合から最大値を出す
  for(int i=0; i<size; i++) {
    for(int s=0; s<(1<<size); s++) {
      if ((s>>i&1)==1) {
        dp[dp_index][s]= max(dp[dp_index][s^(1<<i)], dp[dp_index][s]);
      }
    }
  }

計算量

O(N/2 × 2^(N/2))

ソースコード

#include <algorithm>
#include <cassert>
#include <climits>
#include <cmath>
#include <iostream>
#include <map>
#include <queue>
#include <string>
#include <vector>

using namespace std;

typedef long long ll;
typedef unsigned long long ull;
typedef long double ld;
typedef vector<int> VI;
typedef vector<ll> VL;
typedef pair<int, int> ipair;
typedef tuple<int, int, int> ituple;

#define PI acosl(-1)
#define MAX_N (40 + 2)
#define SIZE_PER_ONCE 20

const ull ONE = 1;

// 国の集合の中の総和の最大
ull dp[2][1 << SIZE_PER_ONCE];

map<string, int> name2index;
vector<string> tmp_neighbor[MAX_N];
int power[MAX_N];
ull neighbor[MAX_N];

void set_dp(int dp_index, int start_index, int size) {
  // ビットマスクiに含まれる国の軍事力の和を求める
  for (int i = 1; i < (1 << size); i++){
    ull b = 0;
    ull score = 0;
    for (int j = 0; j < size; j++){
      if ((i & (1 << j)) == 0) {
        continue;
      }
      if ((b & neighbor[start_index + j]) != 0) {
        score = 0;
        break;
      }
      score += power[start_index + j];
      b |= ONE << (start_index + j);
    }
    dp[dp_index][i] = score;
  }

  if (start_index == 0) {
    return;
  }

  // 高速メビウス変換でビットマスクsの部分集合から最大値を出す
  for(int i=0; i<size; i++) {
    for(int s=0; s<(1<<size); s++) {
      if ((s>>i&1)==1) {
        dp[dp_index][s]= max(dp[dp_index][s^(1<<i)], dp[dp_index][s]);
      }
    }
  }

}

void init() {
  for (int i = 0; i < MAX_N; i++){
      tmp_neighbor[i].clear();
  }

  name2index.clear();
}

void exec(int n){
  string a;
  int b, c;
  string s;

  init();

  for (int i = 0; i < n; i++){
    cin >> a;
    scanf("%d%d", &b, &c);
    name2index[a] = i;
    power[i] = b;

    for (int j = 0; j < c; j++){
      cin >> s;
      tmp_neighbor[i].push_back(s);
    }
  }

  // 隣接する国のビット表現を作る
  for (int i = 0; i < n; i++){
    ull tmp = 0;
    tmp |= ONE << i;

    for (int j = 0; j < tmp_neighbor[i].size(); j++){
      tmp |= ONE << name2index[tmp_neighbor[i][j]];
    }

    neighbor[i] = tmp;
  }

  ull ans = 0;
  if (n <= SIZE_PER_ONCE) {
    // nが小さい場合は分割しない
    set_dp(0, 0, n);

    for (int i = 0; i < (ONE << n); i++){
      if ((i & 1) == 1) {
        ans = max(ans, dp[0][i]);
      }
    }
  }
  else {
    // 前半分を計算
    set_dp(0, 0, SIZE_PER_ONCE);
    // 後ろ半分を計算
    set_dp(1, SIZE_PER_ONCE, n - SIZE_PER_ONCE);

    for (int i = 0; i < (ONE << SIZE_PER_ONCE); i++){
      if ((i & 1) == 1) {
        ull tmp = dp[0][i];

        // 前半分の集合から、同盟を組める後ろ半分の集合を求める
        ull tmp_mask = 0, mask = 0;
        for (int j = 0; j < SIZE_PER_ONCE; j++){
          if (((i >> j) & 1) == 1) {
            tmp_mask |= neighbor[j];
          }
        }
        tmp_mask >>= SIZE_PER_ONCE;

        for (int j = 0; j < n - SIZE_PER_ONCE; j++){
          if (((tmp_mask >> j) & 1) == 0) {
            mask |= ONE << j;
          }
        }

        // 後ろ半分の集合の軍事力の最大値を加算
        tmp += dp[1][mask];

        ans = max(tmp, ans);
      }
    }
  }

  cout << ans << endl;
}

void solve(){
  int t = 1;
  while (scanf("%d", &t) != EOF) {
    if (t == 0) {
      break;
    }
    exec(t);
  }
}

int main(){
  solve();
  return 0;
}