Friday, September 4, 2015

Find the distance between two nodes in a Binary Tree


Find the distance between two keys in a binary tree; no parent pointers are given. The distance between two nodes is the minimum number of edges that must be traversed to reach one node from the other.

Dist(-4,3) = 2
Dist(-4,19) = 4
Dist(21,-4) = 3
Dist(2,-4) = 1


The distance between two nodes can be expressed in terms of their lowest common ancestor (LCA) using the following formula:

Dist(n1, n2) = Dist(root, n1) + Dist(root, n2) - 2*Dist(root, lca) 
'n1' and 'n2' are the two given keys.
'root' is the root of the given binary tree.
'lca' is the lowest common ancestor of n1 and n2.
Dist(n1, n2) is the distance between n1 and n2.

For example, take the case of Dist(-4,3), where 5 is the root:
LCA(-4,3) = 2
Dist(-4,3) = Dist(5,-4) + Dist(5,3) - 2 * Dist(5,2) = 3 + 3 - 2 * 2 = 2
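The formula can be tried out directly with a minimal Java sketch that uses separate traversals: one level search per key, one standard recursive LCA search, and one more level search for the LCA. It is a simpler variant, not the single-pass approach of this post. It assumes unique keys. The class DistanceByFormula, the helpers depth and lca, and the tree built in exampleTree() are hypothetical: only the keys 5, 2, -4 and 3 and their depths follow the example above, while 7 and 8 are invented filler nodes.

```java
// Sketch of Dist(n1, n2) = Dist(root, n1) + Dist(root, n2) - 2 * Dist(root, lca),
// computed with separate traversals. Assumes unique keys. The class name,
// method names and the example tree are hypothetical.
class DistanceByFormula {
    static class Node {
        int key;
        Node left, right;
        Node(int key, Node left, Node right) {
            this.key = key;
            this.left = left;
            this.right = right;
        }
        Node(int key) { this(key, null, null); }
    }

    // Number of edges from root to key k, or -1 if k is not in the tree
    static int depth(Node root, int k, int level) {
        if (root == null) return -1;
        if (root.key == k) return level;
        int l = depth(root.left, k, level + 1);
        return (l != -1) ? l : depth(root.right, k, level + 1);
    }

    // Standard recursive LCA; only meaningful when both keys are present
    static Node lca(Node root, int n1, int n2) {
        if (root == null || root.key == n1 || root.key == n2) return root;
        Node l = lca(root.left, n1, n2);
        Node r = lca(root.right, n1, n2);
        if (l != null && r != null) return root;  // keys split across subtrees
        return (l != null) ? l : r;
    }

    // Returns the distance between n1 and n2, or -1 if either key is missing
    static int distance(Node root, int n1, int n2) {
        int d1 = depth(root, n1, 0);
        int d2 = depth(root, n2, 0);
        if (d1 == -1 || d2 == -1) return -1;
        int dLca = depth(root, lca(root, n1, n2).key, 0);
        return d1 + d2 - 2 * dLca;
    }

    // Hypothetical tree: root 5, node 2 at depth 2 with children -4 and 3.
    // Keys 7 and 8 are invented fillers.
    //          5
    //         / \
    //        7   8
    //       /
    //      2
    //     / \
    //   -4   3
    static Node exampleTree() {
        Node two = new Node(2, new Node(-4), new Node(3));
        return new Node(5, new Node(7, two, null), new Node(8));
    }

    public static void main(String[] args) {
        // 3 + 3 - 2 * 2 = 2
        System.out.println(distance(exampleTree(), -4, 3));  // prints 2
    }
}
```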

Now let's write the code.


// Binary tree node
class Node {
    int key;
    Node left, right;
    Node(int key) { this.key = key; }
}

class BinaryTree {
    // Returns the level of key k if it is present in the tree, otherwise -1
    int findLevel(Node root, int k, int level) {
        // Base case
        if (root == null)
            return -1;
        // If the key is at the root, return its level;
        // otherwise search the left subtree, then the right subtree
        if (root.key == k)
            return level;
        int l = findLevel(root.left, k, level + 1);
        return (l != -1) ? l : findLevel(root.right, k, level + 1);
    }

    // This function returns the LCA of the two given keys n1 and n2.
    // It also sets d1, d2 and dist if neither key is an ancestor of the other.
    // Java has no out parameters, so d1, d2 and dist are one-element int
    // arrays whose first element findDistUtil writes:
    // d1[0]   - distance of n1 from root
    // d2[0]   - distance of n2 from root
    // dist[0] - distance between n1 and n2
    // lvl     - level (distance from root) of the current node
    Node findDistUtil(Node root, int n1, int n2, int[] d1, int[] d2,
                      int[] dist, int lvl) {
        // Base case
        if (root == null) return null;
        // If either n1 or n2 matches the root's key, report its presence
        // by returning root (if one key is an ancestor of the other,
        // the ancestor key becomes the LCA)
        if (root.key == n1) {
            d1[0] = lvl;
            return root;
        }
        if (root.key == n2) {
            d2[0] = lvl;
            return root;
        }
        // Look for n1 and n2 in the left and right subtrees
        Node leftLca  = findDistUtil(root.left,  n1, n2, d1, d2, dist, lvl + 1);
        Node rightLca = findDistUtil(root.right, n1, n2, d1, d2, dist, lvl + 1);
        // If both calls return non-null, one key is present in one subtree
        // and the other key in the other subtree, so this node is the LCA
        if (leftLca != null && rightLca != null) {
            dist[0] = d1[0] + d2[0] - 2 * lvl;
            return root;
        }
        // Otherwise pass up whichever subtree result is non-null
        return (leftLca != null) ? leftLca : rightLca;
    }

    // The main function that returns the distance between n1 and n2.
    // It returns -1 if either n1 or n2 is not present in the binary tree.
    int findDistance(Node root, int n1, int n2) {
        // d1 (distance of n1 from root) and d2 (distance of n2 from root)
        // start at -1, meaning "not found"; dist is the distance between n1 and n2
        int[] d1 = {-1}, d2 = {-1}, dist = {-1};
        Node lca = findDistUtil(root, n1, n2, d1, d2, dist, 1);
        // If both n1 and n2 were found, findDistUtil has already set dist
        if (d1[0] != -1 && d2[0] != -1)
            return dist[0];
        // If n1 is an ancestor of n2, consider n1 as root and find the level
        // of n2 in the subtree rooted at n1 (-1 if n2 is absent)
        if (d1[0] != -1)
            return findLevel(lca, n2, 0);
        // If n2 is an ancestor of n1, consider n2 as root and find the level
        // of n1 in the subtree rooted at n2 (-1 if n1 is absent)
        if (d2[0] != -1)
            return findLevel(lca, n1, 0);
        return -1;
    }
}

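findDistance() is the main function. It calls findDistUtil, which finds the LCA and, when neither key is an ancestor of the other, also computes the distance in the same traversal.
If n1 is an ancestor of n2 (or vice versa), findDistUtil stops at the ancestor and never reaches the descendant, so findDistance calls findLevel to find the level of the descendant in the subtree rooted at the ancestor. That level is exactly the difference between the two nodes' levels.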
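A small sketch of the ancestor case, assuming unique keys: once the ancestor's node is known, findLevel started at that node with level 0 returns the descendant's level inside that subtree, which equals the difference of the two levels measured from the root. The class AncestorCase, the helper find and the example tree are hypothetical; findLevel mirrors the post's function.

```java
// Sketch of the ancestor case, assuming unique keys. The class name, the
// find() helper and the example tree are hypothetical.
class AncestorCase {
    static class Node {
        int key;
        Node left, right;
        Node(int key, Node left, Node right) {
            this.key = key;
            this.left = left;
            this.right = right;
        }
        Node(int key) { this(key, null, null); }
    }

    // Level of key k relative to root (root is at `level`), or -1 if absent
    static int findLevel(Node root, int k, int level) {
        if (root == null) return -1;
        if (root.key == k) return level;
        int l = findLevel(root.left, k, level + 1);
        return (l != -1) ? l : findLevel(root.right, k, level + 1);
    }

    // Hypothetical helper: returns the node holding key k, or null if absent
    static Node find(Node root, int k) {
        if (root == null || root.key == k) return root;
        Node l = find(root.left, k);
        return (l != null) ? l : find(root.right, k);
    }

    // Hypothetical tree: 5 -> (7 -> (2 -> (-4, 3)), 8)
    static Node exampleTree() {
        Node two = new Node(2, new Node(-4), new Node(3));
        return new Node(5, new Node(7, two, null), new Node(8));
    }

    public static void main(String[] args) {
        Node root = exampleTree();
        // 7 is an ancestor of -4: search for -4 only inside 7's subtree
        int viaSubtree = findLevel(find(root, 7), -4, 0);
        // Same value as the difference of the two levels measured from the root
        int viaLevels = findLevel(root, -4, 0) - findLevel(root, 7, 0);
        System.out.println(viaSubtree + " " + viaLevels);  // prints 2 2
    }
}
```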

Time Complexity: O(n), since findDistUtil visits each node at most once and the optional findLevel call searches only the subtree below the ancestor.

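Note that Java has no out parameters like C#'s out. Passing Integer objects does not work around this: Integer is immutable, and Java passes the reference itself by value, so assigning a new value to d1, d2 or dist inside findDistUtil is not visible to the caller. A mutable holder, such as a one-element int[] or a small class with int fields, gives the pass-by-reference behaviour needed here.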
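The out-parameter pitfall can be seen in isolation. In this sketch (the class OutParamDemo and its methods are hypothetical), assigning to an Integer parameter changes only the method's local copy of the reference, whereas writing into a one-element int[] changes the array that the caller also holds.

```java
// Why Integer parameters cannot act as out parameters in Java, and a
// one-element int[] holder that can. Class and method names are hypothetical.
class OutParamDemo {
    // Reassigning the parameter only rebinds the local copy of the reference;
    // the caller's variable still points to the old Integer.
    static void setWithInteger(Integer d) {
        d = 42;
    }

    // The array reference is copied, but both copies point to the same array,
    // so writing d[0] is visible to the caller.
    static void setWithArray(int[] d) {
        d[0] = 42;
    }

    static int afterInteger() {
        Integer d = -1;
        setWithInteger(d);
        return d;       // still -1
    }

    static int afterArray() {
        int[] d = {-1};
        setWithArray(d);
        return d[0];    // now 42
    }

    public static void main(String[] args) {
        System.out.println(afterInteger());  // prints -1
        System.out.println(afterArray());    // prints 42
    }
}
```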



  Comment: Can't the solution simply be the sum of the distances of n1 and n2 from the LCA,
    dist(n1,n2) = dist(lca,n1) + dist(lca,n2)? Correct me if I am wrong.

    Reply: You are absolutely correct. The solution above computes the same quantity as dist(n1,n2) = dist(root,n1) + dist(root,n2) - 2*dist(root,lca), using the root as the point of reference. Thanks.
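The two formulations agree because, for each key n below the LCA, Dist(root, n) = Dist(root, lca) + Dist(lca, n); substituting this into the root-based formula cancels the 2*Dist(root, lca) term. A minimal sketch of the LCA-based version follows, assuming unique keys. The class DistanceFromLca, its helpers and the example tree are hypothetical.

```java
// Sketch of Dist(n1, n2) = Dist(lca, n1) + Dist(lca, n2), measuring both
// distances from the LCA instead of the root. Assumes unique keys; the class
// name, method names and the example tree are hypothetical.
class DistanceFromLca {
    static class Node {
        int key;
        Node left, right;
        Node(int key, Node left, Node right) {
            this.key = key;
            this.left = left;
            this.right = right;
        }
        Node(int key) { this(key, null, null); }
    }

    // Standard recursive LCA. If only one key is present it returns that
    // key's node, which the level checks in distance() then reject.
    static Node lca(Node root, int n1, int n2) {
        if (root == null || root.key == n1 || root.key == n2) return root;
        Node l = lca(root.left, n1, n2);
        Node r = lca(root.right, n1, n2);
        if (l != null && r != null) return root;
        return (l != null) ? l : r;
    }

    // Level of key k in the subtree rooted at root, or -1 if absent
    static int findLevel(Node root, int k, int level) {
        if (root == null) return -1;
        if (root.key == k) return level;
        int l = findLevel(root.left, k, level + 1);
        return (l != -1) ? l : findLevel(root.right, k, level + 1);
    }

    // Returns the distance between n1 and n2, or -1 if either key is missing
    static int distance(Node root, int n1, int n2) {
        Node l = lca(root, n1, n2);
        if (l == null) return -1;                  // neither key found
        int a = findLevel(l, n1, 0);
        int b = findLevel(l, n2, 0);
        return (a == -1 || b == -1) ? -1 : a + b;
    }

    // Hypothetical tree: 5 -> (7 -> (2 -> (-4, 3)), 8)
    static Node exampleTree() {
        Node two = new Node(2, new Node(-4), new Node(3));
        return new Node(5, new Node(7, two, null), new Node(8));
    }

    public static void main(String[] args) {
        Node t = exampleTree();
        System.out.println(distance(t, -4, 3));   // prints 2
        System.out.println(distance(t, -4, 8));   // prints 4
    }
}
```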