QLITRE DIALY

UnionFindをつい使ってしまう

2023年02月23日

「興味深いことに、金槌しか道具を持っていない人は、何もかも釘であるかのように取り扱う」

— アブラハム・マズロー

有名なゴールデン・ハンマーに関することわざだ。

自分は競技プログラミングをやっているときに、このことわざを思い出すことが多い。

ゼロからコードを書くこともあるが、ネットで拾った強力なアルゴリズムを使用して問題をぶっ叩く時もある。

そういう時に魔法のゴールデン・ハンマーを使っているような気分になる。

実際に強力なツールをうまく使えた時はとても楽しい。競技プログラミングの楽しさの一つはここにあるとさえ考えている。

もちろん全ての問題を解決できるゴールデン・ハンマーは存在しないが、限りなくそれに近いと思っているアルゴリズムがある。

UnionFindだ。

Union-Find はグループ分けを効率的に管理する、根付き木を用いたデータ構造です。

アルゴ式

といわけで、今回はUnionFindを使って問題をぶっ叩いた例をいくつか紹介する。

B - レ

https://atcoder.jp/contests/abc289/tasks/abc289_b

Bレベルなのに異様にてこずった問題。漢文のレ点の読み方をシミュレートする問題だ。

自分はこの問題を素朴な幅優先探索で苦労しながら解いた。しかしUnionFindを使うと簡単に考えられる。

レ点でつながっている箇所をUnionFindでつなげればいい。

from collections import defaultdict


class UnionFindSimple:
    """要素に整数を使用するUnionFind"""

    def __init__(self, n):
        self.n = n
        self.parents = [-1] * n

    def find(self, x):
        if self.parents[x] < 0:
            return x
        else:
            self.parents[x] = self.find(self.parents[x])
            return self.parents[x]

    def union(self, x, y):
        x = self.find(x)
        y = self.find(y)

        if x == y:
            return

        if self.parents[x] > self.parents[y]:
            x, y = y, x

        self.parents[x] += self.parents[y]
        self.parents[y] = x

    def size(self, x):
        return -self.parents[self.find(x)]

    def same(self, x, y):
        return self.find(x) == self.find(y)

    def members(self, x):
        root = self.find(x)
        return [i for i in range(self.n) if self.find(i) == root]

    def roots(self):
        return [i for i, x in enumerate(self.parents) if x < 0]

    def group_count(self):
        return len(self.roots())

    def all_group_members(self):
        group_members = defaultdict(list)
        for member in range(self.n):
            group_members[self.find(member)].append(member)
        return group_members

    def __str__(self):
        return '\n'.join(f'{r}: {m}' for r, m in self.all_group_members().items())


def solve():
    n, m = map(int, input().split())
    uf = UnionFindSimple(n + 1)
    a_list = list(map(int, input().split()))
    for num in a_list:
        # レ点で繋がっている数字をつなげる
        uf.union(num, num + 1)

    ans = []
    for _, numbers in uf.all_group_members().items():
        # 繋がっている数字を降順に並び替える
        numbers.sort(reverse=True)
        for num in numbers:
            ans.append(num)
    # 0始まりのため
    print(*ans[1:])


if __name__ == '__main__':
    solve()

D - Change Usernames

https://atcoder.jp/contests/abc285/tasks/abc285_d

N人のユーザーがいて、元のユーザー名と変更を希望しているユーザー名がある。

ユーザー名を変更する順番は好きに決めていいが、既に存在しているユーザー名に変更をすることができない、というルールがある。

順番を適切に管理することで、すべてのユーザーの希望をかなえられるか判定する問題。

この問題も苦労しながら幅優先探索で解いた。パターンを考えると以下のようにユーザー名が循環しているときに、希望がかなえられないことがわかる。

A → B
B → C
C → A

こういう循環を判定するのは素でやると結構難しい。しかしUnionFindを使うと簡単だ。

from collections import defaultdict


class UnionFindMultiple:
    def __init__(self):
        """
        unionfind経路圧縮あり,要素にtupleや文字列可,始めに要素数指定なし
        """
        self.parents = dict()  # {子要素:親ID,}
        self.members_set = defaultdict(lambda: set())  # keyが根でvalueが根に属する要素要素(tupleや文字列可)
        self.roots_set = set()  # 根の集合(tupleや文字列可)
        self.key_ID = dict()  # 各要素にIDを割り振る
        self.ID_key = dict()  # IDから要素名を復元する
        self.cnt = 0  # IDのカウンター

    def dict_find(self, x):  # 要素名とIDをやり取りするところ
        if x in self.key_ID:
            return self.key_ID[x]
        else:
            self.cnt += 1
            self.key_ID[x] = self.cnt
            self.parents[x] = self.cnt
            self.ID_key[self.cnt] = x
            self.members_set[x].add(x)
            self.roots_set.add(x)
            return self.key_ID[x]

    def find(self, x):
        id_x = self.dict_find(x)
        if self.parents[x] == id_x:
            return x
        else:
            self.parents[x] = self.key_ID[self.find(self.ID_key[self.parents[x]])]
            return self.ID_key[self.parents[x]]

    def union(self, x, y):
        x = self.find(x)
        y = self.find(y)
        if self.parents[x] > self.parents[y]:
            x, y = y, x
        if x == y:
            return
        for i in self.members_set[y]:
            self.members_set[x].add(i)
        self.members_set[y] = set()
        self.roots_set.remove(y)
        self.parents[y] = self.key_ID[x]

    def size(self, x):  # xが含まれる集合の要素数
        return len(self.members_set[self.find(x)])

    def same(self, x, y):  # 同じ集合に属するかの判定
        return self.find(x) == self.find(y)

    def members(self, x):  # xを含む集合の要素
        return self.members_set[self.find(x)]

    def roots(self):  # 根の要素
        return self.roots_set

    def group_count(self):  # 根の数
        return len(self.roots_set)

    def all_group_members(self):  # 根とその要素
        return {r: self.members_set[r] for r in self.roots_set}


def job():
    n = int(input())
    uni = UnionFindMultiple()
    for _ in range(n):
        a, b = map(str, input().split())
        # a,bが同じグループだった場合、循環している
        if uni.same(a, b):
            exit(print('No'))
        uni.union(a, b)
    # 循環が無ければ可能
    print('Yes')


job()

D - Takahashi's Solitaire

https://atcoder.jp/contests/abc277/tasks/abc277_d

N枚のカードを持っている状態で、一枚ずつ場にカードを出していく。最後に出したカードと同じか、(カードの数字+1) MOD Mが書かれたカードがあれば場に出すことができる。そのようなルールでゲームを行い、残ったカードの総和の最小値を求める問題。

これは初見では解くことができなかった。

解説では深さ優先探索を用いた実装例が紹介されていたが、UnionFindを使って殴ることもできる。

つまり出せるカードの組み合わせをグルーピングしておけば答えが出せそうだ。

その際、グループ化されていないカードも考慮することが必要。

例えば100を1枚、同じグループの5,6,7が手札としてあった場合、一回しか場に出せないが、100を出した方がスコアは良くなる。

from collections import defaultdict


class UnionFindMultiple:
    def __init__(self):
        """
        unionfind経路圧縮あり,要素にtupleや文字列可,始めに要素数指定なし
        """
        self.parents = dict()  # {子要素:親ID,}
        self.members_set = defaultdict(lambda: set())  # keyが根でvalueが根に属する要素要素(tupleや文字列可)
        self.roots_set = set()  # 根の集合(tupleや文字列可)
        self.key_ID = dict()  # 各要素にIDを割り振る
        self.ID_key = dict()  # IDから要素名を復元する
        self.cnt = 0  # IDのカウンター

    def dict_find(self, x):  # 要素名とIDをやり取りするところ
        if x in self.key_ID:
            return self.key_ID[x]
        else:
            self.cnt += 1
            self.key_ID[x] = self.cnt
            self.parents[x] = self.cnt
            self.ID_key[self.cnt] = x
            self.members_set[x].add(x)
            self.roots_set.add(x)
            return self.key_ID[x]

    def find(self, x):
        id_x = self.dict_find(x)
        if self.parents[x] == id_x:
            return x
        else:
            self.parents[x] = self.key_ID[self.find(self.ID_key[self.parents[x]])]
            return self.ID_key[self.parents[x]]

    def union(self, x, y):
        x = self.find(x)
        y = self.find(y)
        if self.parents[x] > self.parents[y]:
            x, y = y, x
        if x == y:
            return
        for i in self.members_set[y]:
            self.members_set[x].add(i)
        self.members_set[y] = set()
        self.roots_set.remove(y)
        self.parents[y] = self.key_ID[x]

    def size(self, x):  # xが含まれる集合の要素数
        return len(self.members_set[self.find(x)])

    def same(self, x, y):  # 同じ集合に属するかの判定
        return self.find(x) == self.find(y)

    def members(self, x):  # xを含む集合の要素
        return self.members_set[self.find(x)]

    def roots(self):  # 根の要素
        return self.roots_set

    def group_count(self):  # 根の数
        return len(self.roots_set)

    def all_group_members(self):  # 根とその要素
        return {r: self.members_set[r] for r in self.roots_set}


def solve():
    n, m = map(int, input().split())
    cards = list(map(int, input().split()))
    card_count = defaultdict(int)
    # 書かれている数字ごとに枚数を数える
    for num in cards:
        card_count[num] += 1
    # 1種類しか数字がなければ0にできる
    if len(card_count) == 1:
        exit(print(0))
    # 数字ごとにグループを作る
    uni = UnionFindMultiple()
    for num, cnt in card_count.items():
        a_mod = (num + 1) % m
        # +1 mod mがあればグループにできる
        if card_count.get(a_mod):
            uni.union(num, a_mod)

    # 取り除ける総和の最大を調べる
    ans = -10 ** 18
    # 普通に同じ数字のみを取り除いた場合
    for num, cnt in card_count.items():
        ans = max(ans, num * cnt)
    # グループの数字を取り除いた場合
    for _, numbers in uni.all_group_members().items():
        ans_tmp = 0
        for num in numbers:
            cnt = card_count[num]
            ans_tmp += num * cnt
        ans = max(ans, ans_tmp)

    total = sum(cards)
    print(total - ans)


solve()

D - Friends

https://atcoder.jp/contests/abc177/tasks/abc177_d

友達の情報が与えられて、すべてのグループで友達同士がいないという状況を作るために、何グループ必要か?ということを解く問題。

これもUnionFindを使わなければ難しい。

考えると、一番人数が多いグループの人数だけグループを作ることが最適なことが分かる。

例えば以下のように3人と2人のグループがあるとする。

A,B,C
D,E

こういう状況の時にBとCを別のグループにする。他のグループの人間は新しくできたグループに突っ込めば、友達同士にならない。

UnionFindしてグループ数の最大を出力すればAC。

from collections import defaultdict


class UnionFindSimple:
    """要素に整数を使用するUnionFind"""

    def __init__(self, n):
        self.n = n
        self.parents = [-1] * n

    def find(self, x):
        if self.parents[x] < 0:
            return x
        else:
            self.parents[x] = self.find(self.parents[x])
            return self.parents[x]

    def union(self, x, y):
        x = self.find(x)
        y = self.find(y)

        if x == y:
            return

        if self.parents[x] > self.parents[y]:
            x, y = y, x

        self.parents[x] += self.parents[y]
        self.parents[y] = x

    def size(self, x):
        return -self.parents[self.find(x)]

    def same(self, x, y):
        return self.find(x) == self.find(y)

    def members(self, x):
        root = self.find(x)
        return [i for i in range(self.n) if self.find(i) == root]

    def roots(self):
        return [i for i, x in enumerate(self.parents) if x < 0]

    def group_count(self):
        return len(self.roots())

    def all_group_members(self):
        group_members = defaultdict(list)
        for member in range(self.n):
            group_members[self.find(member)].append(member)
        return group_members

    def __str__(self):
        return '\n'.join(f'{r}: {m}' for r, m in self.all_group_members().items())


def solve():
    n, m = map(int, input().split())
    uni = UnionFindSimple(n)
    for _ in range(m):
        a, b = map(int, input().split())
        uni.union(a - 1, b - 1)
    ans = -10 ** 18
    for _, members in uni.all_group_members().items():
        ans = max(ans, len(members))

    print(ans)


solve()

C - K Swap

https://atcoder.jp/contests/abc254/tasks/abc254_c

数列とkが与えられて、k個先の数字をswapさせることで、数列を昇順に並び替えることができるか、という問題。

例えば9個の数列があって、kが2の場合。

1,3,5,7,9

2,4,6,8

番目の数字を入れ替えることで昇順にできるか調べる。

自分はfor 文の添え字を管理することが苦手で、UnionFindを使って考えた方がしっくりくる。

つまり、添え字をグルーピングして後から答えを構成してしまえばいい。

from collections import defaultdict


class UnionFindSimple:
    """要素に整数を使用するUnionFind"""

    def __init__(self, n):
        self.n = n
        self.parents = [-1] * n

    def find(self, x):
        if self.parents[x] < 0:
            return x
        else:
            self.parents[x] = self.find(self.parents[x])
            return self.parents[x]

    def union(self, x, y):
        x = self.find(x)
        y = self.find(y)

        if x == y:
            return

        if self.parents[x] > self.parents[y]:
            x, y = y, x

        self.parents[x] += self.parents[y]
        self.parents[y] = x

    def size(self, x):
        return -self.parents[self.find(x)]

    def same(self, x, y):
        return self.find(x) == self.find(y)

    def members(self, x):
        root = self.find(x)
        return [i for i in range(self.n) if self.find(i) == root]

    def roots(self):
        return [i for i, x in enumerate(self.parents) if x < 0]

    def group_count(self):
        return len(self.roots())

    def all_group_members(self):
        group_members = defaultdict(list)
        for member in range(self.n):
            group_members[self.find(member)].append(member)
        return group_members

    def __str__(self):
        return '\n'.join(f'{r}: {m}' for r, m in self.all_group_members().items())


def solve():
    n, k = map(int, input().split())
    uni = UnionFindSimple(n)
    a_list = list(map(int, input().split()))

    for i in range(n - k):
        # 今のインデックスとk個先のindexをグルーピングする
        uni.union(i, i + k)

    ans = [-1] * n
    for _, index_list in uni.all_group_members().items():
        # indexメンバーの配列を作る
        tmp = []
        for i in index_list:
            tmp.append(a_list[i])
        # 昇順に並び替える
        tmp.sort()
        tmp_index = 0
        # 答えにあてはめる
        for i in index_list:
            ans[i] = tmp[tmp_index]
            tmp_index += 1
    # 元の配列をソート
    a_list.sort()
    # 等しかったらyes
    if a_list == ans:
        print('Yes')
    else:
        print('No')


solve()

おわりに

アンチパターンかもしれないが、つながりや循環が関連する問題ではついついUnionFindのことを考えてしまう。

これでうまくいかないときもあるが、うまくいくこともある。

ゴールデン・ハンマーを使って釘を叩くのはとても楽しくて病みつきになってしまう。ではでは。