Graphs

Union-Find

Time: O(alpha(n)), Space: O(n)

Imagine a social network where friend groups merge. Each group has a representative. When two people become friends, their groups merge by having one representative point to the other. To check if two people are in the same group, follow the chain of representatives to the top — if they end at the same person, they’re connected.

When to use it: Dynamic connectivity queries (“are these two nodes connected?”), detecting cycles in undirected graphs, Kruskal’s minimum spanning tree, or grouping elements that can be merged over time.

Key insight: With path compression (flatten the tree on every find) and union by rank (attach the shorter tree under the taller one), both operations run in nearly O(1) amortized time: specifically O(alpha(n)), where alpha is the inverse Ackermann function, which is practically constant.

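Common trap: Forgetting path compression. Union by rank alone keeps tree height at O(log n), so find costs O(log n) without compression. If both optimizations are missing, the tree can become a long chain, degrading find to O(n). Always compress paths during find.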
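One way to compress paths without recursion is a two-pass loop, which may be preferable in Python when a chain could be deep enough to hit the recursion limit. The sketch below is illustrative only: it assumes `parent` is a plain list and uses a standalone `find` function rather than the class in this section.

```python
def find(parent: list[int], x: int) -> int:
    """Return the root of x and point every node on the path at that root."""
    # First pass: walk up until we reach the node that is its own parent.
    root = x
    while parent[root] != root:
        root = parent[root]

    # Second pass: re-point each node on the original path directly at the root.
    while parent[x] != root:
        next_node = parent[x]
        parent[x] = root
        x = next_node
    return root


# A worst-case chain: 4 -> 3 -> 2 -> 1 -> 0, where 0 is the root.
parent = [0, 0, 1, 2, 3]
root = find(parent, 4)
# After this single find, every node on the path points straight at the root.
```

After the call, a later find on any of these nodes reaches the root in one step, which is the flattening effect that path compression is meant to provide.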

Pseudocode:

class UnionFind:
    parent = [0, 1, 2, ..., n-1]     # each node is its own parent
    rank = [0, 0, 0, ..., 0]          # rank (approximate tree height)

function find(x):
    if parent[x] != x:
        parent[x] = find(parent[x])   # path compression
    return parent[x]

function union(x, y):
    rootX = find(x)
    rootY = find(y)

    if rootX == rootY:
        return false                   # already in same set

    # Union by rank: attach shorter tree under taller
    if rank[rootX] < rank[rootY]:
        parent[rootX] = rootY
    else if rank[rootX] > rank[rootY]:
        parent[rootY] = rootX
    else:
        parent[rootY] = rootX
        rank[rootX] = rank[rootX] + 1

    return true

function connected(x, y):
    return find(x) == find(y)

Python implementation:

class UnionFind:
    def __init__(self, n: int):
        self.parent = list(range(n))
        self.rank = [0] * n
        self.count = n  # number of disjoint sets

    def find(self, x: int) -> int:
        if self.parent[x] != x:
            self.parent[x] = self.find(self.parent[x])  # path compression
        return self.parent[x]

    def union(self, x: int, y: int) -> bool:
        root_x, root_y = self.find(x), self.find(y)
        if root_x == root_y:
            return False

        # Union by rank
        if self.rank[root_x] < self.rank[root_y]:
            self.parent[root_x] = root_y
        elif self.rank[root_x] > self.rank[root_y]:
            self.parent[root_y] = root_x
        else:
            self.parent[root_y] = root_x
            self.rank[root_x] += 1

        self.count -= 1
        return True

    def connected(self, x: int, y: int) -> bool:
        return self.find(x) == self.find(y)


def num_connected_components(n: int, edges: list[list[int]]) -> int:
    """Count connected components using Union-Find."""
    uf = UnionFind(n)
    for u, v in edges:
        uf.union(u, v)
    return uf.count


def has_cycle_undirected(n: int, edges: list[list[int]]) -> bool:
    """Detect cycle in an undirected graph using Union-Find."""
    uf = UnionFind(n)
    for u, v in edges:
        if uf.connected(u, v):
            return True  # edge connects two nodes already in the same set
        uf.union(u, v)
    return False
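Kruskal's minimum spanning tree, listed under "When to use it", follows the same pattern as the cycle check: sort edges by weight and keep an edge only if it joins two different sets. The sketch below is a minimal illustration, not a reference implementation. The function name `kruskal_mst` and the `(weight, u, v)` edge layout are assumptions, and it inlines a small union-find (using path halving, a variant of path compression) so that it runs on its own.

```python
def kruskal_mst(
    n: int, edges: list[tuple[int, int, int]]
) -> tuple[int, list[tuple[int, int, int]]]:
    """Return (total weight, chosen edges) of a minimum spanning forest.

    Each edge is a (weight, u, v) tuple; this layout is an assumption of the sketch.
    """
    parent = list(range(n))
    rank = [0] * n

    def find(x: int) -> int:
        while parent[x] != x:
            parent[x] = parent[parent[x]]  # path halving: skip one level per step
            x = parent[x]
        return x

    def union(x: int, y: int) -> bool:
        root_x, root_y = find(x), find(y)
        if root_x == root_y:
            return False  # edge would close a cycle
        if rank[root_x] < rank[root_y]:
            root_x, root_y = root_y, root_x  # make root_x the higher-rank root
        parent[root_y] = root_x
        if rank[root_x] == rank[root_y]:
            rank[root_x] += 1
        return True

    total = 0
    chosen: list[tuple[int, int, int]] = []
    for weight, u, v in sorted(edges):  # cheapest edges first
        if union(u, v):
            total += weight
            chosen.append((weight, u, v))
            if len(chosen) == n - 1:
                break  # a spanning tree has exactly n - 1 edges
    return total, chosen


example_edges = [(1, 0, 1), (4, 0, 2), (3, 1, 2), (2, 2, 3), (5, 1, 3)]
total_weight, mst_edges = kruskal_mst(4, example_edges)
```

If the graph is disconnected, the loop simply runs out of edges and the result is a spanning forest with fewer than n - 1 edges.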

Problem: Given n nodes and a list of edges added one at a time, determine the earliest point at which all nodes become connected.

How would you approach this?

Answer: This is Union-Find. Process edges in order, unioning the two endpoints each time. After each union, check if the number of components has dropped to 1. The edge that triggers this is the answer.
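The approach in this answer can be sketched as follows. The function name `earliest_full_connection`, the 0-based edge index it returns, and the `-1` sentinel for "never fully connected" are all assumptions made for illustration. The union-find is inlined so the block runs on its own.

```python
def earliest_full_connection(n: int, edges: list[tuple[int, int]]) -> int:
    """Return the 0-based index of the edge that first connects all n nodes.

    Returns -1 if the edges never connect every node. With n <= 1 no edge
    triggers the connection, so this sketch also returns -1 in that case.
    """
    parent = list(range(n))
    rank = [0] * n
    components = n  # every node starts in its own set

    def find(x: int) -> int:
        if parent[x] != x:
            parent[x] = find(parent[x])  # path compression
        return parent[x]

    for index, (u, v) in enumerate(edges):
        root_u, root_v = find(u), find(v)
        if root_u == root_v:
            continue  # edge adds nothing new
        # Union by rank: attach the lower-rank root under the higher-rank one.
        if rank[root_u] < rank[root_v]:
            root_u, root_v = root_v, root_u
        parent[root_v] = root_u
        if rank[root_u] == rank[root_v]:
            rank[root_u] += 1
        components -= 1
        if components == 1:
            return index  # this edge merged the last two groups
    return -1


answer = earliest_full_connection(4, [(0, 1), (1, 2), (0, 2), (2, 3), (1, 3)])
```

In the example, the edge at index 2 joins two nodes that are already connected, so the component count only reaches 1 at index 3.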


Problem: Given an undirected graph with n nodes and a list of edges, determine the number of connected components.

How would you approach this?

Answer: This is Union-Find. Initialize n components (one per node). For each edge, union the two endpoints. The final value of count is the number of connected components. Alternatively use DFS/BFS, but Union-Find handles dynamic edge additions better.
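For comparison, the BFS alternative mentioned in this answer might look like the sketch below. The function name `count_components_bfs` is hypothetical. It builds an adjacency list once and floods each unvisited node, so it suits a fixed edge list. When edges keep arriving, the traversal has to be redone, whereas Union-Find absorbs each new edge with a single union.

```python
from collections import deque


def count_components_bfs(n: int, edges: list[tuple[int, int]]) -> int:
    """Count connected components of an undirected graph with BFS."""
    adjacency: list[list[int]] = [[] for _ in range(n)]
    for u, v in edges:
        adjacency[u].append(v)
        adjacency[v].append(u)

    seen = [False] * n
    components = 0
    for start in range(n):
        if seen[start]:
            continue
        # An unvisited node starts a new component; flood everything reachable.
        components += 1
        seen[start] = True
        queue = deque([start])
        while queue:
            node = queue.popleft()
            for neighbor in adjacency[node]:
                if not seen[neighbor]:
                    seen[neighbor] = True
                    queue.append(neighbor)
    return components


component_count = count_components_bfs(5, [(0, 1), (1, 2), (3, 4)])
```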

Time: O(alpha(n)) amortized per operation, where alpha is the inverse Ackermann function, assuming both path compression and union by rank are used. For all practical input sizes, alpha(n) <= 5, making each operation effectively O(1). Over m operations, total time is O(m * alpha(n)).

Space: O(n) for the parent and rank arrays, one entry per element.