
Optimal Binary Search Tree

Complexity summary:
  • Brute Force: Time O(n·3^n), Space O(n)
  • Optimal (DP): Time O(n^3), Space O(n^2)

Problem Statement

Given n sorted keys and their search frequencies (how often each key is searched), construct a Binary Search Tree (BST) over those keys that minimizes the total search cost. Searching for a key at depth d (the root is at depth 0) costs frequency × (d + 1), and the total cost is the sum of this quantity over all keys. Return the minimum total search cost.

Example:

  • Input: keys = [10, 12, 20], freq = [34, 8, 50]
  • Output: 142
  • Explanation: In the optimal BST, 20 is the root (depth 0), 10 is its left child (depth 1), and 12 is the right child of 10 (depth 2). Cost = 50×1 + 34×2 + 8×3 = 50 + 68 + 24 = 142.
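To double-check the arithmetic in the example, here is a minimal sketch that evaluates the cost formula for one fixed tree shape. It assumes the shape is described by a `depth` array giving each key's depth (root = 0); `TreeCost` and `totalCost` are hypothetical names used only for illustration, not part of the original solution.

```java
class TreeCost {
    // Total cost = sum over keys of freq[k] * (depth[k] + 1), with the root at depth 0.
    // depth[k] is an assumed input describing where key k sits in some fixed BST.
    static int totalCost(int[] freq, int[] depth) {
        if (freq.length != depth.length) {
            throw new IllegalArgumentException("freq and depth must have the same length");
        }
        int total = 0;
        for (int k = 0; k < freq.length; k++) {
            // Reaching a node at depth d visits d + 1 nodes on the way down.
            total += freq[k] * (depth[k] + 1);
        }
        return total;
    }
}
```

For the example tree, keys 10, 12 and 20 sit at depths 1, 2 and 0, so `TreeCost.totalCost(new int[]{34, 8, 50}, new int[]{1, 2, 0})` evaluates 34×2 + 8×3 + 50×1 = 142.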

Noob-Friendly Explanation

Imagine you have a dictionary of words that people search for. Some words are super popular (searched a lot), and some are rare. You want to organize these words in a tree so that popular words are near the top (quick to find) and rare words sit deeper. The catch is that the tree must still be a valid BST, so the alphabetical order of the words limits which word can sit above which.

It's like organizing a store: you put the most popular items near the entrance and less popular ones in the back. The goal is to minimize the average time customers spend searching.


Approach 1: Brute Force (Recursion)

Intuition

Try every key as the root. For each choice of root, the keys smaller than it form the left subtree and the keys larger than it form the right subtree. Recursively find the optimal cost for each subtree.
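Why adding the range sum once is enough can be checked directly. Write f_k for freq[k], and suppose key r is the root of a tree T on keys i..j, with left subtree L on keys i..r-1 and right subtree R on keys r+1..j. A key k at depth d_L(k) inside L sits at depth d_L(k) + 1 in T (similarly for R), so its cost term grows from f_k (d_L(k) + 1) to f_k (d_L(k) + 2):

$$
\begin{aligned}
\mathrm{Cost}(T) &= f_r \cdot 1 + \sum_{k=i}^{r-1} f_k \bigl(d_L(k) + 2\bigr) + \sum_{k=r+1}^{j} f_k \bigl(d_R(k) + 2\bigr) \\
&= \underbrace{\sum_{k=i}^{r-1} f_k \bigl(d_L(k) + 1\bigr)}_{\mathrm{Cost}(L)} + \underbrace{\sum_{k=r+1}^{j} f_k \bigl(d_R(k) + 1\bigr)}_{\mathrm{Cost}(R)} + \sum_{k=i}^{j} f_k
\end{aligned}
$$

Once r is fixed, L and R can be optimized independently, which gives the recurrence used by both approaches (an empty range costs 0):

$$
dp[i][j] = \sum_{k=i}^{j} f_k + \min_{i \le r \le j} \bigl( dp[i][r-1] + dp[r+1][j] \bigr), \qquad dp[i][i-1] = dp[j+1][j] = 0
$$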

Algorithm

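  1. For keys i to j, try each key r as the root.
  2. Cost = cost of left subtree + cost of right subtree + sum of all frequencies in [i..j]. The root r is reached in one step (freq[r] × 1), and hanging both subtrees under r pushes every key in them one level deeper, adding each of their frequencies once more; together these extra terms are exactly sum(freq[i..j]).
  3. Pick the root that gives the minimum total cost.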
java
class Solution {
    public int optimalBST(int[] keys, int[] freq) {
        int n = keys.length;
        return solve(freq, 0, n - 1);
    }

    private int solve(int[] freq, int i, int j) {
        if (i > j) return 0;
        if (i == j) return freq[i];

        int freqSum = 0;
        for (int k = i; k <= j; k++) {
            freqSum += freq[k];
        }

        int minCost = Integer.MAX_VALUE;

        // Try each key as root
        for (int r = i; r <= j; r++) {
            int cost = solve(freq, i, r - 1) + solve(freq, r + 1, j) + freqSum;
            minCost = Math.min(minCost, cost);
        }

        return minCost;
    }
}

Complexity Analysis

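  • Time Complexity: O(n·3^n) - Exponential. The number of calls C(n) on a range of n keys satisfies C(n) = 1 + 2·(C(0) + C(1) + ... + C(n-1)), which grows like 3^n, and each call spends O(n) time summing frequencies and looping over roots. The blow-up comes from re-solving the same (i, j) ranges again and again (overlapping subproblems).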
  • Space Complexity: O(n) - Recursion stack depth.
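Because the brute force keeps re-solving the same (i, j) ranges, one way to bridge it to the table-based approach is to cache each range's answer (top-down memoization). This is an illustrative variant, not part of the original solution: the class name `MemoizedOptimalBST` is hypothetical, and it assumes non-negative frequencies so that -1 can mark an empty cache slot. With O(n^2) ranges, each scanning O(n) roots once, it should run in O(n^3) time and O(n^2) space, like the bottom-up version.

```java
import java.util.Arrays;

class MemoizedOptimalBST {
    // Same signature as the original; keys only fix the sorted order, so the cost uses freq alone.
    static int optimalBST(int[] keys, int[] freq) {
        int n = keys.length;
        int[] prefix = new int[n + 1]; // prefix[k] = freq[0] + ... + freq[k - 1]
        for (int k = 0; k < n; k++) {
            prefix[k + 1] = prefix[k] + freq[k];
        }
        int[][] memo = new int[n][n];
        for (int[] row : memo) {
            Arrays.fill(row, -1); // -1 = "not computed yet" (assumes costs are >= 0)
        }
        return solve(prefix, memo, 0, n - 1);
    }

    private static int solve(int[] prefix, int[][] memo, int i, int j) {
        if (i > j) return 0;                     // empty range costs nothing
        if (memo[i][j] != -1) return memo[i][j]; // reuse a cached answer
        // The range sum covers the root plus one extra level for every subtree key.
        int freqSum = prefix[j + 1] - prefix[i];
        int best = Integer.MAX_VALUE;
        for (int r = i; r <= j; r++) {
            int cost = solve(prefix, memo, i, r - 1) + solve(prefix, memo, r + 1, j) + freqSum;
            best = Math.min(best, cost);
        }
        memo[i][j] = best;
        return best;
    }
}
```

On the example input (keys [10, 12, 20], freq [34, 8, 50]) it returns 142, the same answer as the plain recursion.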

Approach 2: Optimal (Dynamic Programming)

Intuition

Store the answer for each range in a 2D DP table, where dp[i][j] is the minimum search cost for keys i through j (inclusive), so every range is solved only once. Prefix sums of the frequencies make each range-sum query O(1).
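A minimal sketch of the prefix-sum trick on its own, with hypothetical helper names `build` and `rangeSum`: prefix[k] holds the sum of the first k frequencies, so the sum of freq[i..j] is a single subtraction.

```java
class PrefixSums {
    // prefix[k] = freq[0] + ... + freq[k - 1]; prefix has n + 1 entries and prefix[0] = 0.
    static int[] build(int[] freq) {
        int[] prefix = new int[freq.length + 1];
        for (int k = 0; k < freq.length; k++) {
            prefix[k + 1] = prefix[k] + freq[k];
        }
        return prefix;
    }

    // Sum of freq[i..j] (inclusive) in O(1): everything before j + 1 minus everything before i.
    static int rangeSum(int[] prefix, int i, int j) {
        return prefix[j + 1] - prefix[i];
    }
}
```

For freq = [34, 8, 50], the prefix array is [0, 34, 42, 92], and the sum of freq[1..2] is 92 - 34 = 58.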

Algorithm

  1. Precompute prefix sums of frequencies.
  2. dp[i][i] = freq[i] (single key, cost = its frequency).
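  3. For increasing range lengths (2 up to n), try each key in the range as the root and keep the minimum cost.
  4. dp[i][j] = min over all r in [i,j] of (dp[i][r-1] + dp[r+1][j]) + sum(freq[i..j]), where the term for an empty side (the left side when r = i, the right side when r = j) counts as 0.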
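Applying step 4 to the example (freq = [34, 8, 50], so dp[0][0] = 34, dp[1][1] = 8, dp[2][2] = 50, and the range sums for [0..1], [1..2], [0..2] are 42, 58, 92) fills the table as follows; each min lists the candidates r = i, ..., j in order:

$$
\begin{aligned}
dp[0][1] &= 42 + \min(0 + 8,\; 34 + 0) = 50 \\
dp[1][2] &= 58 + \min(0 + 50,\; 8 + 0) = 66 \\
dp[0][2] &= 92 + \min(0 + 66,\; 34 + 50,\; 50 + 0) = 142
\end{aligned}
$$

The minimum for dp[0][2] comes from r = 2 (key 20), and for dp[0][1] from r = 0 (key 10), which matches the tree described in the example.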
java
class Solution {
    public int optimalBST(int[] keys, int[] freq) {
        int n = keys.length;
        if (n == 0) return 0; // no keys: zero cost (otherwise dp[0][n - 1] would be out of bounds)
        int[][] dp = new int[n][n];

        // Precompute prefix sums for quick range sum
        int[] prefixSum = new int[n + 1];
        for (int i = 0; i < n; i++) {
            prefixSum[i + 1] = prefixSum[i] + freq[i];
        }

        // Base case: single keys
        for (int i = 0; i < n; i++) {
            dp[i][i] = freq[i];
        }

        // Fill ranges of increasing length (len = number of keys in [i..j])
        for (int len = 2; len <= n; len++) {
            for (int i = 0; i <= n - len; i++) {
                int j = i + len - 1;
                dp[i][j] = Integer.MAX_VALUE;

                // Sum of frequencies from i to j
                int freqSum = prefixSum[j + 1] - prefixSum[i];

                // Try each key as root
                for (int r = i; r <= j; r++) {
                    int leftCost = (r > i) ? dp[i][r - 1] : 0;
                    int rightCost = (r < j) ? dp[r + 1][j] : 0;
                    int cost = leftCost + rightCost + freqSum;
                    dp[i][j] = Math.min(dp[i][j], cost);
                }
            }
        }

        return dp[0][n - 1];
    }
}

Complexity Analysis

  • Time Complexity: O(n^3) - Three nested loops (length, start, root).
  • Space Complexity: O(n^2) - 2D DP table of size n×n.
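The DP above returns only the cost, while the problem statement also talks about constructing the tree. One common extension, sketched here with hypothetical names (`OptimalBSTWithRoots`, `buildTree`, and the output format are all illustrative), is to record the winning root of every range in a second table and then walk that table recursively. It adds another O(n^2) table and leaves the O(n^3) time unchanged.

```java
class OptimalBSTWithRoots {
    // Same table fill as the DP solution, plus root[i][j] = index of the chosen root for keys i..j.
    // Returns one optimal tree as "key(left,right)", "-" for an empty subtree, a bare key for a leaf.
    static String buildTree(int[] keys, int[] freq) {
        int n = keys.length;
        if (n == 0) return "-";
        int[][] dp = new int[n][n];
        int[][] root = new int[n][n];
        int[] prefix = new int[n + 1];
        for (int k = 0; k < n; k++) {
            prefix[k + 1] = prefix[k] + freq[k];
        }
        for (int i = 0; i < n; i++) {
            dp[i][i] = freq[i];
            root[i][i] = i;
        }
        for (int len = 2; len <= n; len++) {
            for (int i = 0; i <= n - len; i++) {
                int j = i + len - 1;
                int freqSum = prefix[j + 1] - prefix[i];
                dp[i][j] = Integer.MAX_VALUE;
                for (int r = i; r <= j; r++) {
                    int left = (r > i) ? dp[i][r - 1] : 0;
                    int right = (r < j) ? dp[r + 1][j] : 0;
                    int cost = left + right + freqSum;
                    if (cost < dp[i][j]) { // strict '<' keeps the leftmost optimal root on ties
                        dp[i][j] = cost;
                        root[i][j] = r;
                    }
                }
            }
        }
        return describe(keys, root, 0, n - 1);
    }

    // Rebuilds the tree for keys i..j by following the recorded roots.
    private static String describe(int[] keys, int[][] root, int i, int j) {
        if (i > j) return "-";
        int r = root[i][j];
        if (i == j) return String.valueOf(keys[r]);
        return keys[r] + "(" + describe(keys, root, i, r - 1) + "," + describe(keys, root, r + 1, j) + ")";
    }
}
```

For the example it returns "20(10(-,12),-)": 20 at the root, 10 as its left child, and 12 as the right child of 10. Because ties keep the leftmost root, other equally cheap trees may exist for inputs with tied costs.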