
Find distance between 2 nodes in a Binary tree

Brute Force: Time O(n), Space O(n)
Optimal: Time O(n), Space O(h)

Problem Statement

Given a binary tree and two nodes (identified by their values, which are assumed to be unique), find the distance between them. The distance is the number of edges on the path between the two nodes.

Example:

        1
       / \
      2   3
     / \   \
    4   5   6
       /
      7

Distance(4, 5) = 2   (4 → 2 → 5)
Distance(4, 6) = 4   (4 → 2 → 1 → 3 → 6)
Distance(7, 6) = 5   (7 → 5 → 2 → 1 → 3 → 6)

Approach 1: Brute Force (Find Paths and Calculate)

Intuition

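Find the path from the root to each node. Both paths start with the same prefix, which ends at the lowest common ancestor (LCA). The distance is the combined length of the two paths with that shared prefix removed from each, which is why the prefix is subtracted twice.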

Algorithm

  1. Find path from root to node1
  2. Find path from root to node2
  3. Find the LCA (last common node in both paths)
  4. Distance = len(path1) + len(path2) - 2 * len(common path), where each length counts nodes and the common path ends at the LCA
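Tracing these steps on the example tree, with every length counted in nodes:

```latex
\begin{aligned}
\text{path}(7) &= [1, 2, 5, 7], \quad \text{path}(6) = [1, 3, 6], \quad \text{common} = [1] \\
\text{Distance}(7, 6) &= 4 + 3 - 2 \cdot 1 = 5 \\[4pt]
\text{path}(4) &= [1, 2, 4], \quad \text{path}(5) = [1, 2, 5], \quad \text{common} = [1, 2] \\
\text{Distance}(4, 5) &= 3 + 3 - 2 \cdot 2 = 2
\end{aligned}
```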
```java
import java.util.*;

class TreeNode {
    int val;
    TreeNode left, right;
    TreeNode(int x) { val = x; }
}

public class Solution {
    // Appends the root-to-target path to `path` and returns true;
    // returns false and leaves `path` unchanged if target is absent.
    private boolean findPath(TreeNode root, int target, List<TreeNode> path) {
        if (root == null) return false;

        path.add(root);

        if (root.val == target) return true;

        if (findPath(root.left, target, path) ||
            findPath(root.right, target, path)) {
            return true;
        }

        // Target is not in this subtree: backtrack
        path.remove(path.size() - 1);
        return false;
    }

    public int distanceBruteForce(TreeNode root, int node1, int node2) {
        List<TreeNode> path1 = new ArrayList<>();
        List<TreeNode> path2 = new ArrayList<>();

        if (!findPath(root, node1, path1) || !findPath(root, node2, path2)) {
            return -1;  // One or both nodes not found
        }

        // i = length of the common prefix; path1.get(i - 1) is the LCA.
        // Reference comparison (==) is intended: both paths hold the same TreeNode objects.
        int i = 0;
        while (i < path1.size() && i < path2.size() && path1.get(i) == path2.get(i)) {
            i++;
        }

        // Edges below the LCA on each side: (size - 1) - (i - 1) = size - i
        return (path1.size() - i) + (path2.size() - i);
    }
}
```

Complexity Analysis

  • Time Complexity: O(n) - Two traversals to find paths
  • Space Complexity: O(n) - The two stored paths and the recursion stack are each O(h), which is O(n) in the worst case (a skewed tree)

Approach 2: Optimal (Using LCA)

Intuition

The path between two nodes always passes through their lowest common ancestor (LCA), so Distance = Distance(LCA, node1) + Distance(LCA, node2). First find the LCA, then measure the distance from it down to each node.

Formula

Distance(node1, node2) = Distance(root, node1) + Distance(root, node2) - 2 * Distance(root, LCA)

Or equivalently, with Level(x) the depth of x in edges from the root (root at level 0): Distance(node1, node2) = Level(node1) + Level(node2) - 2 * Level(LCA)
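A minimal sketch of the level-based form, assuming unique node values. The class name LevelDistance and the helpers level and buildExample are hypothetical, not part of the original solution. On the example tree, Level(7) = 3, Level(6) = 2, and their LCA is the root (level 0), so Distance(7, 6) = 3 + 2 - 2 * 0 = 5.

```java
// Hypothetical sketch: Distance = Level(node1) + Level(node2) - 2 * Level(LCA).
// Assumes node values are unique.
class LevelDistance {
    static class TreeNode {
        int val;
        TreeNode left, right;
        TreeNode(int x) { val = x; }
    }

    // Recursive LCA search, the same idea as findLCA in this approach's code.
    static TreeNode lca(TreeNode root, int p, int q) {
        if (root == null || root.val == p || root.val == q) return root;
        TreeNode left = lca(root.left, p, q);
        TreeNode right = lca(root.right, p, q);
        if (left != null && right != null) return root;
        return left != null ? left : right;
    }

    // Level (depth in edges, root = 0) of the node holding target, or -1 if absent.
    static int level(TreeNode root, int target, int depth) {
        if (root == null) return -1;
        if (root.val == target) return depth;
        int left = level(root.left, target, depth + 1);
        return left != -1 ? left : level(root.right, target, depth + 1);
    }

    static int distance(TreeNode root, int node1, int node2) {
        int level1 = level(root, node1, 0);
        int level2 = level(root, node2, 0);
        if (level1 == -1 || level2 == -1) return -1;  // a node is missing

        // Both nodes exist, so lca() returns their true lowest common ancestor.
        TreeNode ancestor = lca(root, node1, node2);
        int lcaLevel = level(root, ancestor.val, 0);
        return level1 + level2 - 2 * lcaLevel;
    }

    // Builds the example tree from the problem statement.
    static TreeNode buildExample() {
        TreeNode root = new TreeNode(1);
        root.left = new TreeNode(2);
        root.right = new TreeNode(3);
        root.left.left = new TreeNode(4);
        root.left.right = new TreeNode(5);
        root.right.right = new TreeNode(6);
        root.left.right.left = new TreeNode(7);
        return root;
    }

    public static void main(String[] args) {
        TreeNode root = buildExample();
        System.out.println(distance(root, 4, 5));  // 2
        System.out.println(distance(root, 4, 6));  // 4
        System.out.println(distance(root, 7, 6));  // 5
    }
}
```

This version makes several separate traversals; it is meant to show the formula, not to beat the LCA-distance code in speed.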

Algorithm

  1. Find LCA of node1 and node2
  2. Find distance from LCA to node1
  3. Find distance from LCA to node2
  4. Return sum of both distances
```java
class TreeNode {
    int val;
    TreeNode left, right;
    TreeNode(int x) { val = x; }
}

public class Solution {
    // Returns the LCA of the nodes with values p and q when both are present.
    // If only one is present, it returns that node instead.
    private TreeNode findLCA(TreeNode root, int p, int q) {
        if (root == null) return null;
        if (root.val == p || root.val == q) return root;

        TreeNode left = findLCA(root.left, p, q);
        TreeNode right = findLCA(root.right, p, q);

        if (left != null && right != null) return root;
        return left != null ? left : right;
    }

    // Distance (in edges) from root down to target, or -1 if target is not in this subtree
    private int findDistance(TreeNode root, int target, int distance) {
        if (root == null) return -1;

        if (root.val == target) return distance;

        int left = findDistance(root.left, target, distance + 1);
        if (left != -1) return left;

        return findDistance(root.right, target, distance + 1);
    }

    public int distanceBetweenNodes(TreeNode root, int node1, int node2) {
        // Find LCA
        TreeNode lca = findLCA(root, node1, node2);

        if (lca == null) return -1;  // Neither node is in the tree

        // Find distances from LCA to both nodes
        int dist1 = findDistance(lca, node1, 0);
        int dist2 = findDistance(lca, node2, 0);

        // If only one node exists, findLCA returned that node and the other search fails
        if (dist1 == -1 || dist2 == -1) return -1;

        return dist1 + dist2;
    }
}
```

Single Pass Solution
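One way to find the LCA and both distances in a single traversal is sketched below, assuming unique node values; the names SinglePassDistance and buildExample are hypothetical. A post-order traversal returns, from each call, the distance down to whichever target lies in that subtree. The node where the two targets meet (either in different subtrees, or a target with the other one below it) is the LCA, and the distance is recorded there.

```java
// Hypothetical single-pass sketch: one post-order traversal finds the LCA
// and both distances at once. Assumes node values are unique.
class SinglePassDistance {
    static class TreeNode {
        int val;
        TreeNode left, right;
        TreeNode(int x) { val = x; }
    }

    private int answer;

    int distance(TreeNode root, int node1, int node2) {
        answer = -1;  // stays -1 if either node is missing
        dfs(root, node1, node2);
        return answer;
    }

    // Post-order helper. Returns the distance (in edges) from `node` down to a
    // target in its subtree, or -1 if there is none to report. Sets `answer`
    // at the node where the two targets meet (their LCA).
    private int dfs(TreeNode node, int p, int q) {
        if (node == null) return -1;
        int left = dfs(node.left, p, q);
        int right = dfs(node.right, p, q);

        if (node.val == p || node.val == q) {
            // This node is a target; if the other target is below it, this node is the LCA.
            if (p == q) answer = 0;
            else if (left != -1) answer = left + 1;
            else if (right != -1) answer = right + 1;
            return 0;
        }
        if (left != -1 && right != -1) {
            // Targets are in different subtrees, so this node is the LCA.
            answer = (left + 1) + (right + 1);
            return -1;  // both targets are accounted for
        }
        if (left != -1) return left + 1;
        if (right != -1) return right + 1;
        return -1;
    }

    // Builds the example tree from the problem statement.
    static TreeNode buildExample() {
        TreeNode root = new TreeNode(1);
        root.left = new TreeNode(2);
        root.right = new TreeNode(3);
        root.left.left = new TreeNode(4);
        root.left.right = new TreeNode(5);
        root.right.right = new TreeNode(6);
        root.left.right.left = new TreeNode(7);
        return root;
    }

    public static void main(String[] args) {
        TreeNode root = buildExample();
        SinglePassDistance solver = new SinglePassDistance();
        System.out.println(solver.distance(root, 7, 6));  // 5
        System.out.println(solver.distance(root, 2, 7));  // 2
    }
}
```

Keeping the result in a field leaves the helper's return value free for the distance-to-target; because the field is mutable, one instance should not be shared across threads.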

Complexity Analysis

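  • Time Complexity: O(n) - The code above makes at most three traversals (findLCA plus one findDistance per node), each O(n); the single-pass variant does all the work in one traversal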
  • Space Complexity: O(h) - Recursion stack depth

Key Takeaways

  1. LCA is crucial: the path between the two nodes always passes through the LCA, where their root paths meet
  2. Formula: Distance = dist(LCA, node1) + dist(LCA, node2)
  3. Alternative formula: Distance = level(node1) + level(node2) - 2 * level(LCA)
  4. Single pass possible: Find LCA and distances simultaneously
  5. Edge cases: same node (distance = 0); one node is an ancestor of the other (the ancestor is itself the LCA)
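Both edge cases follow directly from the level formula. Checking them on the example tree (root at level 0), where LCA(5, 5) = 5 and LCA(2, 7) = 2:

```latex
\begin{aligned}
\text{Distance}(5, 5) &= \text{Level}(5) + \text{Level}(5) - 2 \cdot \text{Level}(5) = 2 + 2 - 4 = 0 \\
\text{Distance}(2, 7) &= \text{Level}(2) + \text{Level}(7) - 2 \cdot \text{Level}(2) = 1 + 3 - 2 = 2
\end{aligned}
```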