/**
 * Computes the distance (number of edges) between two nodes of a binary tree,
 * where the nodes are identified by their data values.
 *
 * All lookups stop at the first matching node, so the results are only
 * well-defined when the searched values occur at most once in the tree.
 */
class GfG {
    /**
     * Returns the number of edges on the path between the node holding
     * {@code a} and the node holding {@code b}.
     *
     * The path between two nodes always passes through their lowest common
     * ancestor (LCA), so the distance is the depth of {@code a} below the LCA
     * plus the depth of {@code b} below the LCA.
     *
     * @param root root of the tree; may be null
     * @param a    value of the first node
     * @param b    value of the second node
     * @return the edge count between the two nodes; 0 when the tree is empty,
     *         when {@code a == b}, or (given unique values) when either value
     *         is absent from the tree
     */
    int findDist(Node root, int a, int b) {
        Node lcaNode = lca(root, a, b);

        // Measure from the LCA, not from the root: measuring from the root
        // would count the shared root-to-LCA segment twice.
        int dA = minDistance(lcaNode, a, 0);
        int dB = minDistance(lcaNode, b, 0);
        return dA + dB;
    }

    /**
     * Searches the subtree rooted at {@code node} (left subtree first) for a
     * node whose data equals {@code a}.
     *
     * @param node  subtree root; may be null
     * @param a     value to search for
     * @param edges number of edges already walked to reach {@code node};
     *              callers start with 0
     * @return {@code edges} plus the depth of the match below {@code node},
     *         or 0 when the value is not found. With a starting value of 0
     *         the result 0 therefore means either "matched at the starting
     *         node" or "not found".
     */
    int minDistance(Node node, int a, int edges) {
        if (node == null) {
            return 0;
        }
        if (node.data == a) {
            return edges;
        }

        // Any match below this node is reported as at least edges + 1, so
        // (for non-negative edges) a 0 from the left subtree can only mean
        // "not found" and the right subtree is searched instead.
        int l = minDistance(node.left, a, edges + 1);
        if (l != 0) {
            return l;
        }
        return minDistance(node.right, a, edges + 1);
    }

    /**
     * Finds the lowest common ancestor of the nodes holding {@code n1} and
     * {@code n2} in the subtree rooted at {@code root}.
     *
     * @param root subtree root; may be null
     * @param n1   value of the first node
     * @param n2   value of the second node
     * @return the LCA when both values are present; when only one value is
     *         present, the node holding it; null when neither is present
     */
    Node lca(Node root, int n1, int n2) {
        if (root == null) {
            return null;
        }

        // A node holding either value is returned without descending further:
        // if the other value lies below it, this node is itself the LCA.
        if (root.data == n1 || root.data == n2) {
            return root;
        }

        Node left = lca(root.left, n1, n2);
        Node right = lca(root.right, n1, n2);

        // One value was found on each side, so the paths meet at this node.
        if (left != null && right != null) {
            return root;
        }

        // Otherwise propagate whichever side found something (or null).
        return left != null ? left : right;
    }
}
