Union Find (Disjoint Set)

Master the art of grouping elements and detecting connections efficiently

Union Find Pattern

Union Find (also known as Disjoint Set) is a data structure that keeps track of elements split into one or more disjoint sets. It provides near-constant-time operations to unite sets and determine if elements are in the same set.

Understanding Through Examples

  1. Friend Circles
    """
    Problem: Given friendships between people, find friend circles
    
    Example:
    People: A, B, C, D, E
    Friendships: A-B, B-C, D-E
    
    Visual representation:
    A──B──C    D──E
    
    Friend circles: 
    Circle 1: {A, B, C}
    Circle 2: {D, E}
    """
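    from typing import List

    def findCircles(n: int, friendships: List[List[int]]) -> int:
        # People are identified by index 0..n-1 (A=0, B=1, ...)
        uf = UnionFind(n)  # UnionFind is defined under "Core Implementation"
        for a, b in friendships:
            uf.union(a, b)
        return uf.count  # Number of distinct circles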
    

Understanding Union Find Deeply

  1. Real-World Analogy

    Think of Union Find like managing social networks:
    
    Scenario: School clubs merging
    - Each student starts in their own group
    - When clubs merge, all members join together
    - Need to quickly check if students are in same club
    
    Example:
    Initial: Art={A,B}, Music={C,D}, Drama={E,F}
    Merge Art & Music:
    - Now {A,B,C,D} is one group
    - Drama {E,F} remains separate
    
    Just like Union Find:
    - Each element starts as its own set
    - Union combines sets
    - Find checks set membership
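
The club analogy above can be sketched directly with student names as keys. This is a minimal dictionary-based illustration, not the optimized structure: the names `parent`, `find`, and `union` are local to this sketch, and it leaves out the rank and compression optimizations.

```python
# Each student starts as their own representative.
parent = {s: s for s in "ABCDEF"}

def find(s: str) -> str:
    # Follow parent links until reaching a representative.
    while parent[s] != s:
        s = parent[s]
    return s

def union(a: str, b: str) -> None:
    # Merge the group containing a into the group containing b.
    parent[find(a)] = find(b)

# Form the initial clubs: Art={A,B}, Music={C,D}, Drama={E,F}
union("A", "B")
union("C", "D")
union("E", "F")

# Merge Art and Music by uniting any one member of each.
union("A", "C")

same_club_a_d = find("A") == find("D")  # A and D now share a club
same_club_a_e = find("A") == find("E")  # Drama stays separate
print(same_club_a_d, same_club_a_e)  # True False
```

Uniting a single member from each club is enough, because `find` always resolves to the representative of the whole group.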
    
  2. Visual Tree Growth

    How sets combine over time:
    
    Initial state:
    1   2   3   4   5   (each number points to itself)
    
    After union(1,2):
    2←1   3   4   5    (1 points to 2)
    
    After union(3,4):
    2←1   4←3   5      (3 points to 4)
    
    After union(2,4):
        4
       ↗ ↖
      2   3
      ↑
      1
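
The diagram above can be replayed on a plain parent array. The sketch below assumes a naive rule (always hang the first argument's root under the second's), which is what the diagram depicts; the rank-based `union` shown later may choose a different root.

```python
# parent[i] is the node that i points to; index 0 is unused so labels match the diagram.
parent = list(range(6))

def root(x: int) -> int:
    # Walk parent links until a node points to itself.
    while parent[x] != x:
        x = parent[x]
    return x

def naive_union(x: int, y: int) -> None:
    # Assumption for this sketch: always hang x's root under y's root.
    parent[root(x)] = root(y)

naive_union(1, 2)
naive_union(3, 4)
naive_union(2, 4)
print(parent[1:])  # [2, 4, 4, 4, 5]
```

Node 1 still points to 2, and 2 now points to 4, so `root(1)` walks two links to reach 4.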
    

Core Implementation

class UnionFind:
    """
    Key operations:
    1. find(x): Find set representative
    2. union(x, y): Unite two sets
    3. connected(x, y): Check if in same set
    
    Optimizations:
    1. Path compression in find()
    2. Union by rank (or by size, also called weighted union)
    
    Time complexity:
    - O(α(n)) amortized per operation with both optimizations
    - α(n) is the inverse Ackermann function (at most 4 for any practical n)
    """
    def __init__(self, n: int):
        self.parent = list(range(n))
        self.rank = [0] * n
        self.count = n  # Number of distinct sets
        
    def find(self, x: int) -> int:
        # Path compression
        if self.parent[x] != x:
            self.parent[x] = self.find(self.parent[x])
        return self.parent[x]
        
    def union(self, x: int, y: int) -> bool:
        px, py = self.find(x), self.find(y)
        if px == py:
            return False
            
        # Union by rank
        if self.rank[px] < self.rank[py]:
            px, py = py, px
        self.parent[py] = px
        if self.rank[px] == self.rank[py]:
            self.rank[px] += 1
            
        self.count -= 1
        return True
        
    def connected(self, x: int, y: int) -> bool:
        return self.find(x) == self.find(y)
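
One practical note on the recursive `find`: union by rank keeps trees shallow, but if parent links are ever set some other way, a long chain could exceed CPython's default recursion limit (typically 1000). A possible iterative variant is sketched below as a standalone function; `find_iterative` is a name introduced here for illustration and is not part of the class above.

```python
from typing import List

def find_iterative(parent: List[int], x: int) -> int:
    # First pass: walk up to the root.
    root = x
    while parent[root] != root:
        root = parent[root]
    # Second pass: point every node on the path directly at the root.
    while parent[x] != root:
        parent[x], x = root, parent[x]
    return root

# A deliberately degenerate chain 0 -> 1 -> ... -> n-1, deeper than the default recursion limit.
n = 100_000
chain = [min(i + 1, n - 1) for i in range(n)]
r = find_iterative(chain, 0)
print(r, chain[0], chain[n // 2])
```

The two-pass loop performs the same full path compression as the recursive version, without using the call stack.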

Advanced Applications

  1. Minimum Spanning Tree (Kruskal’s Algorithm)

    def minimumSpanningTree(n: int, edges: List[List[int]]) -> int:
        """
        Find minimum cost to connect all nodes
        
        Example:
        Nodes: A, B, C
        Edges: A-B(2), B-C(3), A-C(5)
        
        Process:
        1. Sort edges by weight
        2. Try adding each edge
        3. Skip if creates cycle
        
        Visual steps:
        Step 1: A──(2)──B
        Step 2: B──(3)──C
        Final: A──(2)──B──(3)──C
        """
        uf = UnionFind(n)
        edges.sort(key=lambda x: x[2])  # Sort by weight
        cost = 0
        
        for u, v, weight in edges:
            if uf.union(u, v):
                cost += weight
                
        return cost if uf.count == 1 else -1
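
To make the docstring example concrete, here is a self-contained sketch that also records which edges were kept. It uses a compact inline `find` rather than the `UnionFind` class, maps A, B, C to 0, 1, 2 (an assumption of this sketch), and omits the connectivity check that returns -1.

```python
from typing import List, Tuple

Edge = Tuple[int, int, int]

def kruskal(n: int, edges: List[Edge]) -> Tuple[int, List[Edge]]:
    parent = list(range(n))

    def find(x: int) -> int:
        while parent[x] != x:
            parent[x] = parent[parent[x]]  # shorten the path while walking up
            x = parent[x]
        return x

    cost, chosen = 0, []
    for u, v, w in sorted(edges, key=lambda e: e[2]):  # sorted() leaves the input untouched
        ru, rv = find(u), find(v)
        if ru != rv:  # different components, so the edge cannot close a cycle
            parent[ru] = rv
            cost += w
            chosen.append((u, v, w))
    return cost, chosen

# A=0, B=1, C=2 with edges A-B(2), B-C(3), A-C(5)
total, tree = kruskal(3, [(0, 2, 5), (0, 1, 2), (1, 2, 3)])
print(total, tree)  # 5 [(0, 1, 2), (1, 2, 3)]
```

The edge A-C(5) is skipped because A and C already share a root by the time it is considered.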
    
  2. Network Connectivity

    def criticalConnections(n: int, connections: List[List[int]]) -> List[List[int]]:
        """
        Find bridges in network (critical connections)
        
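        Example network:
        1──2──3
        │  │
        └──4
        
        Critical connections:
        - 2-3 (removing it disconnects node 3)
        
        Why Union Find?
        - Efficiently track connected components
        - Detect bridges by removing each edge and re-checking connectivity
          (brute force: one Union Find rebuild per edge, assuming the
          network starts out connected)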
        """
        def areConnected(skip_edge: int) -> bool:
            uf = UnionFind(n)
            for i, (u, v) in enumerate(connections):
                if i != skip_edge:
                    uf.union(u, v)
            return uf.count == 1
            
        bridges = []
        for i in range(len(connections)):
            if not areConnected(i):
                bridges.append(connections[i])
        return bridges
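
The check above flags every edge if the network is disconnected to begin with. One way to relax that, sketched below under the assumption of 0-indexed nodes, is to compare component counts: an edge is a bridge if leaving it out raises the count. The helper names `count_components` and `bridges_by_removal` are introduced here for illustration only.

```python
from typing import List

def count_components(n: int, edges: List[List[int]]) -> int:
    parent = list(range(n))

    def find(x: int) -> int:
        while parent[x] != x:
            parent[x] = parent[parent[x]]
            x = parent[x]
        return x

    components = n
    for u, v in edges:
        ru, rv = find(u), find(v)
        if ru != rv:
            parent[ru] = rv
            components -= 1
    return components

def bridges_by_removal(n: int, edges: List[List[int]]) -> List[List[int]]:
    baseline = count_components(n, edges)
    # An edge is a bridge if leaving it out raises the component count.
    return [e for i, e in enumerate(edges)
            if count_components(n, edges[:i] + edges[i + 1:]) > baseline]

# Hypothetical network: a triangle 0-1-2 with a tail edge 2-3.
found = bridges_by_removal(4, [[0, 1], [1, 2], [2, 0], [2, 3]])
print(found)  # [[2, 3]]
```

This is still a brute-force approach that rebuilds the structure once per edge, so it suits small inputs.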
    

Common Patterns and Techniques

  1. Path Compression

    """
    Before compression:
    1 → 2 → 3 → 4 → 5
    
    After find(1) with compression:
    1 → 5
    2 → 5
    3 → 5
    4 → 5
    5
    
    Benefits:
    - Flattens tree structure
    - Improves future operations
    - Nearly constant time finds
    """
    def find(self, x: int) -> int:
        if self.parent[x] != x:
            self.parent[x] = self.find(self.parent[x])
        return self.parent[x]
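
The before and after states in the docstring can be reproduced on a bare parent list. This sketch uses a module-level `parent` list instead of the class, with index 0 unused so the labels match.

```python
# parent[i] for nodes 1..5 (index 0 unused): the chain 1 -> 2 -> 3 -> 4 -> 5
parent = [0, 2, 3, 4, 5, 5]

def find(x: int) -> int:
    if parent[x] != x:
        parent[x] = find(parent[x])  # rewire x straight to the root on the way back
    return parent[x]

before = parent[1:]  # slicing copies, so this snapshot is not affected by find()
root = find(1)
after = parent[1:]
print(before, root, after)  # [2, 3, 4, 5, 5] 5 [5, 5, 5, 5, 5]
```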
    
  2. Union by Rank

    """
    Without rank:
    Could become linear:
    1
    └── 2
        └── 3
            └── 4
    
    With rank:
    Balanced tree:
        3
      ↙   ↘
     1     4
      ↘
       2
    """
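    def union(self, x: int, y: int):
        px, py = self.find(x), self.find(y)
        if px == py:  # Already in the same set; nothing to merge
            return
        if self.rank[px] < self.rank[py]:
            self.parent[px] = py  # Lower-rank root goes under the higher-rank root
        else:
            self.parent[py] = px
            if self.rank[px] == self.rank[py]:
                self.rank[px] += 1  # Equal ranks: the merged tree grows by one level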
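
To see the difference in tree height, the sketch below unites consecutive elements with and without the rank rule and measures the result. Path compression is left out on purpose so that only the effect of rank is visible; `height` and `build` are helper names introduced for this illustration.

```python
from typing import List

def height(parent: List[int]) -> int:
    # Longest chain of parent links from any node to its root.
    def depth(x: int) -> int:
        d = 0
        while parent[x] != x:
            x = parent[x]
            d += 1
        return d
    return max(depth(i) for i in range(len(parent)))

def build(n: int, use_rank: bool) -> List[int]:
    parent, rank = list(range(n)), [0] * n

    def find(x: int) -> int:  # no compression, to isolate the effect of rank
        while parent[x] != x:
            x = parent[x]
        return x

    for i in range(n - 1):  # union(i, i + 1) for every consecutive pair
        px, py = find(i), find(i + 1)
        if use_rank and rank[px] >= rank[py]:
            parent[py] = px
            if rank[px] == rank[py]:
                rank[px] += 1
        else:
            parent[px] = py  # naive rule, or the lower-rank root going under py
    return parent

naive_height = height(build(8, use_rank=False))
ranked_height = height(build(8, use_rank=True))
print(naive_height, ranked_height)  # 7 1
```

With the naive rule the eight elements form a single chain; with rank, every later element attaches directly under the first root.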
    

Implementation Details

  1. Path Compression Explained

    """
    Why path compression matters:
    
    Without compression:
    find(1) traversal:
    1 → 2 → 3 → 4 → 5  (5 steps)
    1 → 2 → 3 → 4 → 5  (5 steps again)
    
    With compression:
    First find(1):
    1 → 2 → 3 → 4 → 5  (5 steps)
    Updates to:
    1 → 5
    2 → 5
    3 → 5
    4 → 5
    
    Next find(1):
    1 → 5  (just 2 steps!)
    
    Performance impact:
    - Without: O(N) per find in the worst case (a degenerate chain)
    - With: O(log N) amortized on its own; O(α(N)) amortized, effectively
      constant, when combined with union by rank
    """
    def find(self, x: int) -> int:
        if self.parent[x] != x:  # If not root
            # Recursively set parent to root
            self.parent[x] = self.find(self.parent[x])
        return self.parent[x]
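
The step counts quoted in the docstring can be checked by counting the nodes each lookup visits. The function below is a hypothetical instrumented variant written for this purpose, not a replacement for `find`.

```python
from typing import List

def find_counting(parent: List[int], x: int, compress: bool) -> int:
    """Return how many nodes the lookup visits, optionally compressing the path."""
    path = [x]
    while parent[path[-1]] != path[-1]:
        path.append(parent[path[-1]])
    if compress:
        for node in path:
            parent[node] = path[-1]  # path[-1] is the root
    return len(path)

chain = [0, 2, 3, 4, 5, 5]  # nodes 1..5, index 0 unused
plain = [find_counting(chain, 1, compress=False) for _ in range(2)]

chain = [0, 2, 3, 4, 5, 5]  # reset before the compressed run
compressed = [find_counting(chain, 1, compress=True) for _ in range(2)]
print(plain, compressed)  # [5, 5] [5, 2]
```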
    
  2. Union by Rank Visualization

    """
    Why rank matters:
    
    Bad case (without rank):
    1
    ↓
    2
    ↓
    3  (Height: 3)
    
    Good case (with rank):
      2
     ↙ ↘
    1   3  (Height: 2)
    
    Example union sequence:
    1. union(2,1)
       Rank[2] = 0, Rank[1] = 0
       Result: 2←1, Rank[2] = 1
          (ranks tie, so the first argument's root becomes the parent)
    
    2. union(2,3)
       Rank[2] = 1, Rank[3] = 0
       Result: 
          2
         ↙ ↘
        1   3
    """
    def union(self, x: int, y: int) -> bool:
        px, py = self.find(x), self.find(y)
        if px == py:  # Already in same set
            return False
            
        # Attach smaller rank tree under root of higher rank tree
        if self.rank[px] < self.rank[py]:
            self.parent[px] = py
        else:
            self.parent[py] = px
            if self.rank[px] == self.rank[py]:
                self.rank[px] += 1  # Increase rank if same
                
        self.count -= 1  # One less set
        return True
    

Common Use Cases

  1. Network Connectivity

    """
    Problem: Check if computer network is fully connected
    
    Example network:
    A──B  C──D
       ↘    ↙
         E
    
    Steps:
    1. union(A,B)
    2. union(B,E)
    3. union(C,D)
    4. union(D,E)
    
    Check connectivity:
    - find(A) == find(C)  # True, all connected
    - Initially separate components merge
    - Final state: all nodes share same root
    """
    def isNetworkConnected(n: int, connections: List[List[int]]) -> bool:
        uf = UnionFind(n)
        for a, b in connections:
            uf.union(a, b)
        return uf.count == 1  # All nodes in one set
    
  2. Redundant Connections

    """
    Problem: Find redundant edge in graph
    
    Example:
    1──2──3
    │     │
    └──4──┘
    
    Process:
    1. Try each edge:
       - 1-2: unite sets
       - 2-3: unite sets
       - 3-4: unite sets
       - 4-1: redundant! (already connected)
    
    Why it works:
    - Union Find tracks connectivity
    - When find() returns same root, found cycle
    """
    def findRedundantConnection(edges: List[List[int]]) -> List[int]:
        # Assumes nodes are labeled 1..n with n == len(edges); index 0 is unused
        uf = UnionFind(len(edges) + 1)
        for u, v in edges:
            if not uf.union(u, v):  # Already connected
                return [u, v]
        return []
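
If the node labels are not known to be 1..n, a dictionary-backed variant avoids sizing the structure up front. The sketch below registers labels lazily; `first_cycle_edge` is a name introduced here, and the approach is one possible alternative rather than the method used above.

```python
from typing import Dict, List

def first_cycle_edge(edges: List[List[int]]) -> List[int]:
    parent: Dict[int, int] = {}

    def find(x: int) -> int:
        parent.setdefault(x, x)  # register unseen labels lazily
        while parent[x] != x:
            parent[x] = parent[parent[x]]
            x = parent[x]
        return x

    for u, v in edges:
        ru, rv = find(u), find(v)
        if ru == rv:
            return [u, v]  # u and v were already connected, so this edge closes a cycle
        parent[ru] = rv
    return []

redundant = first_cycle_edge([[1, 2], [2, 3], [3, 4], [4, 1]])
print(redundant)  # [4, 1]
```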
    

Real-World Applications

  1. Social Network Analysis

    """
    Problem: Friend Recommendation System
    
    Scenario:
    - Users are nodes
    - Friendships are connections
    - Want to find friend groups and suggest connections
    
    Example:
    Users: {Alice, Bob, Charlie, David, Eve}
    Current friends:
    - Alice-Bob
    - Bob-Charlie
    - David-Eve
    
    Implementation:
    """
    class FriendNetwork:
        def __init__(self, users):
            self.uf = UnionFind(len(users))
            self.user_to_id = {user: i for i, user in enumerate(users)}
            
        def add_friendship(self, user1: str, user2: str):
            id1, id2 = self.user_to_id[user1], self.user_to_id[user2]
            self.uf.union(id1, id2)
            
        def are_connected(self, user1: str, user2: str) -> bool:
            id1, id2 = self.user_to_id[user1], self.user_to_id[user2]
            return self.uf.connected(id1, id2)
            
        def suggest_friends(self, user: str) -> List[str]:
            """Suggest friends from same component"""
            user_id = self.user_to_id[user]
            user_group = self.uf.find(user_id)
            return [u for u, i in self.user_to_id.items() 
                   if self.uf.find(i) == user_group 
                   and i != user_id]
    
  2. Grid Connectivity

    """
    Problem: Find connected regions in a grid
    
    Example Grid:
    1 1 0 0 1
    1 1 0 1 0
    0 0 0 0 0
    0 0 0 1 1
    
    Visual representation of regions:
    A A . . B
    A A . C .
    . . . . .
    . . . D D
    
    Implementation strategy:
    1. Each cell is a node
    2. Connect adjacent 1's
    3. Count distinct sets
    """
    def countIslands(grid: List[List[int]]) -> int:
        if not grid: return 0
        rows, cols = len(grid), len(grid[0])
        uf = UnionFind(rows * cols)
        
        def get_id(r: int, c: int) -> int:
            return r * cols + c
            
        # Connect adjacent land cells
        for r in range(rows):
            for c in range(cols):
                if grid[r][c] == 1:
                    for nr, nc in [(r+1,c), (r,c+1)]:  # Down and right
                        if (0 <= nr < rows and 0 <= nc < cols 
                            and grid[nr][nc] == 1):
                            uf.union(get_id(r,c), get_id(nr,nc))
                            
        # Count distinct land masses
        islands = set()
        for r in range(rows):
            for c in range(cols):
                if grid[r][c] == 1:
                    islands.add(uf.find(get_id(r,c)))
                    
        return len(islands)
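
An alternative to collecting roots in a set is to keep a running counter: start with one island per land cell and subtract one for each successful merge. The sketch below is a self-contained version of that idea with an inline `find`; the function name is hypothetical.

```python
from typing import List

def count_islands_by_counter(grid: List[List[int]]) -> int:
    if not grid or not grid[0]:
        return 0
    rows, cols = len(grid), len(grid[0])
    parent = list(range(rows * cols))

    def find(x: int) -> int:
        while parent[x] != x:
            parent[x] = parent[parent[x]]
            x = parent[x]
        return x

    # Start with one island per land cell; every successful merge removes one.
    islands = sum(cell == 1 for row in grid for cell in row)
    for r in range(rows):
        for c in range(cols):
            if grid[r][c] != 1:
                continue
            for nr, nc in ((r + 1, c), (r, c + 1)):  # down, right
                if nr < rows and nc < cols and grid[nr][nc] == 1:
                    a, b = find(r * cols + c), find(nr * cols + nc)
                    if a != b:
                        parent[a] = b
                        islands -= 1
    return islands

example = [
    [1, 1, 0, 0, 1],
    [1, 1, 0, 1, 0],
    [0, 0, 0, 0, 0],
    [0, 0, 0, 1, 1],
]
island_count = count_islands_by_counter(example)
print(island_count)  # 4
```

The example grid has 8 land cells and 4 merges (three inside region A, one inside region D), leaving 4 islands.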
    

Advanced Optimizations

  1. Path Splitting

    """
    Alternative to full path compression
    
    Instead of making all nodes point to root:
    Make each node point to its grandparent
    
    Before:          After find(1):
    1 → 2            1 → 3
    2 → 3            2 → 4
    3 → 4            3 → 5
    4 → 5            4 → 5
    5 (root)         5 (root)
    
    Implementation:
    """
    def find_with_splitting(self, x: int) -> int:
        while self.parent[x] != x:
            nxt = self.parent[x]
            self.parent[x] = self.parent[nxt]  # Point to grandparent
            x = nxt  # Step to the old parent (jumping to the grandparent would be path halving)
        return x
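
Path splitting is easy to confuse with the closely related path halving, which re-points only every other node. The sketch below runs both on the same five-node chain so the resulting parent arrays can be compared; both functions are standalone illustrations operating on a plain list.

```python
from typing import List

def find_splitting(parent: List[int], x: int) -> int:
    # Every node on the path is re-pointed to its grandparent.
    while parent[x] != x:
        parent[x], x = parent[parent[x]], parent[x]
    return x

def find_halving(parent: List[int], x: int) -> int:
    # Every other node is re-pointed to its grandparent; the walk skips ahead.
    while parent[x] != x:
        parent[x] = parent[parent[x]]
        x = parent[x]
    return x

split = [0, 2, 3, 4, 5, 5]  # chain 1 -> 2 -> 3 -> 4 -> 5, index 0 unused
halve = [0, 2, 3, 4, 5, 5]
find_splitting(split, 1)
find_halving(halve, 1)
print(split[1:], halve[1:])  # [3, 4, 5, 5, 5] [3, 3, 5, 5, 5]
```

After splitting, node 2 points to 4; after halving, node 2 is untouched and still points to 3.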
    
  2. Weighted Union

    """
    Alternative to union by rank
    Keep track of tree sizes instead of heights
    
    Benefits:
    - Same O(log n) tree-height bound as union by rank
    - Easier to maintain
    - Useful for size-based queries
    
    Example:
    Trees:    Size=2     Size=3
              1          4
             ↙          ↙ ↘
            2          5   6
    
    After union (the smaller tree's root attaches under the larger root):
           4
         ↙ ↓ ↘
        5  6  1
              ↓
              2
    """
    class WeightedUnionFind:
        def __init__(self, n: int):
            self.parent = list(range(n))
            self.size = [1] * n  # Track size of each tree
            
        def find(self, x: int) -> int:
            # Path compression, as in the core implementation
            if self.parent[x] != x:
                self.parent[x] = self.find(self.parent[x])
            return self.parent[x]
            
        def union(self, x: int, y: int) -> bool:
            px, py = self.find(x), self.find(y)
            if px == py:
                return False
                
            # Attach smaller tree to larger tree
            if self.size[px] < self.size[py]:
                px, py = py, px
            self.parent[py] = px
            self.size[px] += self.size[py]
            return True
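
The "size-based queries" benefit can be illustrated with a small self-contained class. `SizedUnionFind` and its `component_size` method are names assumed for this sketch; the example replays the two trees from the docstring, with index 0 left unused so the labels match.

```python
class SizedUnionFind:
    def __init__(self, n: int) -> None:
        self.parent = list(range(n))
        self.size = [1] * n  # only meaningful at root indices

    def find(self, x: int) -> int:
        while self.parent[x] != x:
            self.parent[x] = self.parent[self.parent[x]]
            x = self.parent[x]
        return x

    def union(self, x: int, y: int) -> bool:
        px, py = self.find(x), self.find(y)
        if px == py:
            return False
        if self.size[px] < self.size[py]:
            px, py = py, px  # make px the root of the larger tree
        self.parent[py] = px
        self.size[px] += self.size[py]
        return True

    def component_size(self, x: int) -> int:
        # The size stored at the root covers the whole component.
        return self.size[self.find(x)]

uf = SizedUnionFind(7)
uf.union(1, 2)  # tree of size 2 rooted at 1
uf.union(4, 5)
uf.union(4, 6)  # tree of size 3 rooted at 4
uf.union(1, 4)  # smaller tree goes under the larger root
print(uf.find(1), uf.component_size(2))  # 4 5
```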


Practice Problems

  1. Basic

    - Number of Connected Components
    - Friend Circles
    - Redundant Connection
    - Graph Valid Tree

  2. Intermediate

    - Number of Islands II
    - Accounts Merge
    - Satisfiability of Equations
    - Most Stones Removed

  3. Advanced

    - Minimize Malware Spread
    - Similar String Groups
    - Regions Cut By Slashes
    - Largest Component Size

Remember: The key to mastering Union Find is to:
1. Understand when to use it (connectivity problems)
2. Implement optimizations correctly
3. Consider the initialization cost
4. Use appropriate variants for the problem
5. Test with various graph structures