Find Distance Between Two Nodes in a Binary Tree
Problem Statement
Given a binary tree and the values of two of its nodes (node values are assumed to be unique), find the distance between the two nodes. The distance is the number of edges on the path connecting them.
Example:
        1
       / \
      2   3
     / \   \
    4   5   6
       /
      7
Distance(4, 5) = 2 (4 → 2 → 5)
Distance(4, 6) = 4 (4 → 2 → 1 → 3 → 6)
Distance(7, 6) = 5 (7 → 5 → 2 → 1 → 3 → 6)
Approach 1: Brute Force (Find Paths and Calculate)
Intuition
Find the path from the root to each node. The two paths share a common prefix that ends at their lowest common ancestor (LCA). The distance is the number of edges left in each path after that shared prefix.
Algorithm
- Find path from root to node1
- Find path from root to node2
- Find the LCA (last common node in both paths)
- Distance = len(path1) + len(path2) - 2 * len(common prefix), where lengths count nodes
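Here is the formula applied to Distance(7, 6) from the example. Path lengths count nodes, and the only node in the common prefix is the root (the LCA):

```latex
\begin{aligned}
\text{path}(7) &= [1, 2, 5, 7], & |\text{path}(7)| &= 4\\
\text{path}(6) &= [1, 3, 6], & |\text{path}(6)| &= 3\\
\text{common prefix} &= [1], & |\text{common}| &= 1\\
\text{Distance}(7, 6) &= 4 + 3 - 2 \cdot 1 = 5
\end{aligned}
```

This matches the 5 edges listed in the example.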
java
import java.util.*;

class TreeNode {
    int val;
    TreeNode left, right;
    TreeNode(int x) { val = x; }
}

public class Solution {
    // Appends the nodes on the root-to-target path to `path`.
    // Returns false (leaving `path` unchanged) if target is not in the tree.
    private boolean findPath(TreeNode root, int target, List<TreeNode> path) {
        if (root == null) return false;
        path.add(root);
        if (root.val == target) return true;
        if (findPath(root.left, target, path) ||
            findPath(root.right, target, path)) {
            return true;
        }
        path.remove(path.size() - 1); // backtrack: target is not below this node
        return false;
    }

    public int distanceBruteForce(TreeNode root, int node1, int node2) {
        List<TreeNode> path1 = new ArrayList<>();
        List<TreeNode> path2 = new ArrayList<>();
        if (!findPath(root, node1, path1) || !findPath(root, node2, path2)) {
            return -1; // One or both nodes not found
        }
        // Find where paths diverge; path1.get(i - 1) is then the LCA.
        // Reference comparison (==) is intended: both paths hold the same node objects.
        int i = 0;
        while (i < path1.size() && i < path2.size() && path1.get(i) == path2.get(i)) {
            i++;
        }
        // Distance = edges from the LCA down to each node
        return (path1.size() - i) + (path2.size() - i);
    }
}
Complexity Analysis
- Time Complexity: O(n) - Two traversals to find paths
- Space Complexity: O(h) for the stored paths and the recursion stack, where h is the tree height (O(n) for a skewed tree)
Approach 2: Optimal (Using LCA)
Intuition
Distance = Distance(LCA, node1) + Distance(LCA, node2). First find the LCA, then calculate distance from LCA to each node.
Formula
Distance(node1, node2) = Distance(root, node1) + Distance(root, node2) - 2 * Distance(root, LCA)
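Or equivalently, since the LCA lies on the root path of both nodes, Distance(LCA, node) = Level(node) - Level(LCA) with Level(root) = 0, which gives: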
Distance(node1, node2) = Level(node1) + Level(node2) - 2 * Level(LCA)
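The sketch below uses the level-based formula directly. It finds each node's level with a top-down search, finds the LCA, and combines the three levels. The class name `DepthFormula` and the helpers `lca`, `depth`, and `buildExample` are hypothetical, and node values are assumed to be unique.

```java
// Level-based sketch: Distance = Level(a) + Level(b) - 2 * Level(LCA), root at level 0.
// Assumes unique node values; class and helper names are illustrative.
class DepthFormula {
    static class TreeNode {
        int val;
        TreeNode left, right;
        TreeNode(int x) { val = x; }
    }

    // Recursive LCA; only called once both values are known to exist
    static TreeNode lca(TreeNode root, int p, int q) {
        if (root == null || root.val == p || root.val == q) return root;
        TreeNode left = lca(root.left, p, q);
        TreeNode right = lca(root.right, p, q);
        if (left != null && right != null) return root; // p and q split here
        return left != null ? left : right;
    }

    // Level (edges from the root) of the node holding target, or -1 if absent
    static int depth(TreeNode node, int target, int level) {
        if (node == null) return -1;
        if (node.val == target) return level;
        int left = depth(node.left, target, level + 1);
        return left != -1 ? left : depth(node.right, target, level + 1);
    }

    static int distance(TreeNode root, int a, int b) {
        int levelA = depth(root, a, 0);
        int levelB = depth(root, b, 0);
        if (levelA == -1 || levelB == -1) return -1; // a node is missing
        int levelLca = depth(root, lca(root, a, b).val, 0);
        return levelA + levelB - 2 * levelLca;
    }

    // Builds the example tree from the problem statement (hypothetical helper)
    static TreeNode buildExample() {
        TreeNode root = new TreeNode(1);
        root.left = new TreeNode(2);
        root.right = new TreeNode(3);
        root.left.left = new TreeNode(4);
        root.left.right = new TreeNode(5);
        root.right.right = new TreeNode(6);
        root.left.right.left = new TreeNode(7);
        return root;
    }

    public static void main(String[] args) {
        TreeNode root = buildExample();
        System.out.println(distance(root, 4, 5)); // 2
        System.out.println(distance(root, 4, 6)); // 4
        System.out.println(distance(root, 7, 6)); // 5
    }
}
```

For Distance(7, 6), Level(7) = 3 and Level(6) = 2. Their LCA is the root at level 0, so the result is 3 + 2 - 2 * 0 = 5. Each search is O(n), so the whole computation stays O(n).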
Algorithm
- Find LCA of node1 and node2
- Find distance from LCA to node1
- Find distance from LCA to node2
- Return sum of both distances
java
class TreeNode {
    int val;
    TreeNode left, right;
    TreeNode(int x) { val = x; }
}

public class Solution {
    // Returns the LCA of p and q. If only one of them is in the tree, that node is
    // returned, so the caller must still check that both nodes exist.
    private TreeNode findLCA(TreeNode root, int p, int q) {
        if (root == null) return null;
        if (root.val == p || root.val == q) return root;
        TreeNode left = findLCA(root.left, p, q);
        TreeNode right = findLCA(root.right, p, q);
        if (left != null && right != null) return root; // p and q are in different subtrees
        return left != null ? left : right;
    }

    // Number of edges from root down to target, or -1 if target is not in this subtree
    private int findDistance(TreeNode root, int target, int distance) {
        if (root == null) return -1;
        if (root.val == target) return distance;
        int left = findDistance(root.left, target, distance + 1);
        if (left != -1) return left;
        return findDistance(root.right, target, distance + 1);
    }

    public int distanceBetweenNodes(TreeNode root, int node1, int node2) {
        // Find LCA
        TreeNode lca = findLCA(root, node1, node2);
        if (lca == null) return -1; // neither node is in the tree
        // Find distances from LCA to both nodes
        int dist1 = findDistance(lca, node1, 0);
        int dist2 = findDistance(lca, node2, 0);
        if (dist1 == -1 || dist2 == -1) return -1; // one node is missing
        return dist1 + dist2;
    }
}
Complexity Analysis
- Time Complexity: O(n). One traversal finds the LCA, then two downward searches from the LCA each take O(n) in the worst case
- Space Complexity: O(h), the recursion stack depth, where h is the tree height
Key Takeaways
- LCA is crucial: the unique path between two nodes always passes through their LCA, where the two root paths meet
- Formula: Distance = dist(LCA, node1) + dist(LCA, node2)
- Alternative formula: dist = level(p) + level(q) - 2*level(LCA)
- Single pass possible: Find LCA and distances simultaneously
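- Edge cases: the same node (distance = 0); one node is an ancestor of the other (the ancestor is the LCA, so the distance is the difference in their levels); a value that is not in the tree (both approaches above return -1)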
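The single-pass idea can be sketched as one post-order traversal. Each call reports how many edges separate the current node from a target found in its subtree, or -1 if there is none. The distance is recorded at the node where the two targets meet, which is the LCA. This is a sketch that assumes unique values. The class name `SinglePassDistance` and the helpers `dfs` and `buildExample` are hypothetical. It handles the edge cases listed above: the same node gives 0, an ancestor pair is recorded at the ancestor, and a missing value leaves the result at -1.

```java
// Single-pass sketch: one post-order DFS finds the LCA and the distance together.
// Assumes unique node values; class and helper names are illustrative.
class SinglePassDistance {
    static class TreeNode {
        int val;
        TreeNode left, right;
        TreeNode(int x) { val = x; }
    }

    private int a, b;     // the two target values
    private int answer;   // distance once known, -1 otherwise

    // Returns the number of edges from `node` down to a target in its subtree,
    // or -1 if no target lies in that subtree. Sets `answer` at the LCA.
    private int dfs(TreeNode node) {
        if (node == null) return -1;
        int left = dfs(node.left);
        int right = dfs(node.right);
        int below = Math.max(left, right); // -1 only if neither child subtree holds a target
        if (node.val == a || node.val == b) {
            if (a == b) answer = 0;                   // both targets are this node
            else if (below != -1) answer = below + 1; // the other target is a descendant
            return 0;
        }
        if (left != -1 && right != -1) {
            answer = left + right + 2; // targets split here, so this node is the LCA
        }
        return below == -1 ? -1 : below + 1;
    }

    public int distance(TreeNode root, int a, int b) {
        this.a = a;
        this.b = b;
        this.answer = -1; // stays -1 if either value is missing
        dfs(root);
        return answer;
    }

    // Builds the example tree from the problem statement (hypothetical helper)
    static TreeNode buildExample() {
        TreeNode root = new TreeNode(1);
        root.left = new TreeNode(2);
        root.right = new TreeNode(3);
        root.left.left = new TreeNode(4);
        root.left.right = new TreeNode(5);
        root.right.right = new TreeNode(6);
        root.left.right.left = new TreeNode(7);
        return root;
    }

    public static void main(String[] args) {
        SinglePassDistance s = new SinglePassDistance();
        TreeNode root = buildExample();
        System.out.println(s.distance(root, 7, 6)); // 5
        System.out.println(s.distance(root, 2, 7)); // 2 (2 is an ancestor of 7)
        System.out.println(s.distance(root, 5, 5)); // 0
    }
}
```

Every node is visited once, so this runs in O(n) time with O(h) stack space. The trade-off is readability: the LCA is never returned explicitly, and it only shows up as the node where `answer` gets set.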