Binary Search Tree


A Binary Search Tree is a binary tree in which, for every node, all nodes in its left subtree have values less than or equal to the node's value, and all nodes in its right subtree have values greater than the node's value.

The Binary Search Tree is one of the most versatile and useful data structures, with real-world applications in searching, sorting, and geometric computation, among many others. In computational geometry, k-d trees and interval trees are both augmentations of the Binary Search Tree. Applications include finding 2-D orthogonal intersections and finding intersections of rectangles, to name just a few.

Before trying to solve problems using a Binary Search Tree, it is very important to be able to implement the basic Binary Search Tree operations effortlessly. This will come in handy later and will make even difficult problems look much easier.
We will see two implementations of the Binary Search Tree in this chapter:
  1. Binary Search Tree that stores (key, value) pair.
  2. Binary Search Tree that stores only value.


#1. Binary Search Tree Implementation storing (key, value) pairs:

In this implementation each node of the Binary Search Tree stores a key and the value associated with that key. The tree is ordered by key, not by value. This implementation is effectively a symbol table of (key, value) pairs. If you want to support more than one value per key, store a list of values in the node instead of a single value.

Java:


/**
 * Ordered symbol table mapping generic keys to values, backed by an
 * (unbalanced) binary search tree ordered by key.
 *
 * <p>Each node caches the size of its subtree, so {@code rank} and
 * {@code select} run in time proportional to the tree height. Null keys are
 * rejected; {@code put(key, null)} removes {@code key}. Queues come from
 * {@code java.util} (written fully qualified) so the class needs nothing
 * beyond the JDK.
 */
public class BST<Key extends Comparable<Key>, Value> {
    private Node root;             // root of BST

    private class Node {
        private final Key key;     // sorted by key
        private Value val;         // associated data
        private Node left, right;  // left and right subtrees
        private int size;          // number of nodes in subtree

        public Node(Key key, Value val, int size) {
            this.key = key;
            this.val = val;
            this.size = size;
        }
    }

    /** Initializes an empty symbol table. */
    public BST() {
    }

    /**
     * Returns the keys in the closed range [lo, hi], in ascending order.
     *
     * @throws IllegalArgumentException if {@code lo} or {@code hi} is null
     */
    public Iterable<Key> keys(Key lo, Key hi) {
        if (lo == null) throw new IllegalArgumentException("first argument to keys() is null");
        if (hi == null) throw new IllegalArgumentException("second argument to keys() is null");

        java.util.Queue<Key> queue = new java.util.ArrayDeque<Key>();
        keys(root, queue, lo, hi);
        return queue;
    }

    /**
     * Returns the number of keys in the closed range [lo, hi] (0 when lo > hi).
     *
     * @throws IllegalArgumentException if {@code lo} or {@code hi} is null
     */
    public int size(Key lo, Key hi) {
        if (lo == null) throw new IllegalArgumentException("first argument to size() is null");
        if (hi == null) throw new IllegalArgumentException("second argument to size() is null");

        if (lo.compareTo(hi) > 0) return 0;
        if (contains(hi)) return rank(hi) - rank(lo) + 1;
        else              return rank(hi) - rank(lo);
    }

    // In-order walk that appends keys within [lo, hi]; it only descends into a
    // child subtree when that subtree can still contain keys in the range.
    private void keys(Node x, java.util.Queue<Key> queue, Key lo, Key hi) {
        if (x == null) return;
        int cmplo = lo.compareTo(x.key);
        int cmphi = hi.compareTo(x.key);
        if (cmplo < 0) keys(x.left, queue, lo, hi);
        if (cmplo <= 0 && cmphi >= 0) queue.add(x.key);
        if (cmphi > 0) keys(x.right, queue, lo, hi);
    }

    // Return key in BST rooted at x of given (0-based) rank.
    // Precondition: rank is in legal range.
    private Key select(Node x, int rank) {
        if (x == null) return null;
        int leftSize = size(x.left);
        if      (leftSize > rank) return select(x.left,  rank);
        else if (leftSize < rank) return select(x.right, rank - leftSize - 1);
        else                      return x.key;
    }

    /**
     * Returns the number of keys in the symbol table strictly less than {@code key}.
     *
     * @throws IllegalArgumentException if {@code key} is null
     */
    public int rank(Key key) {
        if (key == null) throw new IllegalArgumentException("argument to rank() is null");
        return rank(key, root);
    }

    // Number of keys in the subtree rooted at x that are strictly less than key.
    private int rank(Key key, Node x) {
        if (x == null) return 0;
        int cmp = key.compareTo(x.key);
        if      (cmp < 0) return rank(key, x.left);
        else if (cmp > 0) return 1 + size(x.left) + rank(key, x.right);
        else              return size(x.left);
    }

    /**
     * Returns all keys in ascending order. To iterate over all keys of a symbol
     * table named {@code st}, use {@code for (Key key : st.keys())}.
     */
    public Iterable<Key> keys() {
        if (isEmpty()) return new java.util.ArrayDeque<Key>();
        return keys(min(), max());
    }

    /**
     * Returns true if the table holds a value for {@code key}.
     *
     * @throws IllegalArgumentException if {@code key} is null
     */
    public boolean contains(Key key) {
        if (key == null) throw new IllegalArgumentException("argument to contains() is null");
        return get(key) != null;
    }

    /**
     * Returns the value paired with {@code key}, or null if the key is absent.
     *
     * @throws IllegalArgumentException if {@code key} is null
     */
    public Value get(Key key) {
        return get(root, key);
    }

    private Value get(Node x, Key key) {
        if (key == null) throw new IllegalArgumentException("calls get() with a null key");
        if (x == null) return null;
        int cmp = key.compareTo(x.key);
        if      (cmp < 0) return get(x.left, key);
        else if (cmp > 0) return get(x.right, key);
        else              return x.val;
    }

    /** Returns true if this symbol table is empty. */
    public boolean isEmpty() {
        return size() == 0;
    }

    /** Returns the number of key-value pairs in this symbol table. */
    public int size() {
        return size(root);
    }

    // Number of key-value pairs in the BST rooted at x (0 for an empty subtree).
    private int size(Node x) {
        if (x == null) return 0;
        else return x.size;
    }

    /**
     * Inserts the key-value pair, overwriting the old value if {@code key} is
     * already present. A null {@code val} deletes {@code key} instead.
     *
     * @throws IllegalArgumentException if {@code key} is null
     */
    public void put(Key key, Value val) {
        if (key == null) throw new IllegalArgumentException("calls put() with a null key");
        if (val == null) {
            delete(key);
            return;
        }
        root = put(root, key, val);
        assert check();
    }

    private Node put(Node x, Key key, Value val) {
        if (x == null) return new Node(key, val, 1);
        int cmp = key.compareTo(x.key);
        if      (cmp < 0) x.left  = put(x.left,  key, val);
        else if (cmp > 0) x.right = put(x.right, key, val);
        else              x.val   = val;
        x.size = 1 + size(x.left) + size(x.right);
        return x;
    }

    /**
     * Removes the smallest key and its value.
     *
     * @throws java.util.NoSuchElementException if the table is empty
     */
    public void deleteMin() {
        if (isEmpty()) throw new java.util.NoSuchElementException("Symbol table underflow");
        root = deleteMin(root);
        assert check();
    }

    private Node deleteMin(Node x) {
        if (x.left == null) return x.right;
        x.left = deleteMin(x.left);
        x.size = size(x.left) + size(x.right) + 1;
        return x;
    }

    /**
     * Removes the largest key and its value.
     *
     * @throws java.util.NoSuchElementException if the table is empty
     */
    public void deleteMax() {
        if (isEmpty()) throw new java.util.NoSuchElementException("Symbol table underflow");
        root = deleteMax(root);
        assert check();
    }

    private Node deleteMax(Node x) {
        if (x.right == null) return x.left;
        x.right = deleteMax(x.right);
        x.size = size(x.left) + size(x.right) + 1;
        return x;
    }

    /**
     * Removes {@code key} and its value if present; otherwise does nothing.
     *
     * @throws IllegalArgumentException if {@code key} is null
     */
    public void delete(Key key) {
        if (key == null) throw new IllegalArgumentException("calls delete() with a null key");
        root = delete(root, key);
        assert check();
    }

    // Hibbard deletion: a node with two children is replaced by its successor.
    private Node delete(Node x, Key key) {
        if (x == null) return null;

        int cmp = key.compareTo(x.key);
        if      (cmp < 0) x.left  = delete(x.left,  key);
        else if (cmp > 0) x.right = delete(x.right, key);
        else {
            if (x.right == null) return x.left;
            if (x.left  == null) return x.right;
            Node t = x;
            x = min(t.right);
            x.right = deleteMin(t.right);
            x.left = t.left;
        }
        x.size = size(x.left) + size(x.right) + 1;
        return x;
    }

    /**
     * Returns the smallest key.
     *
     * @throws java.util.NoSuchElementException if the table is empty
     */
    public Key min() {
        if (isEmpty()) throw new java.util.NoSuchElementException("calls min() with empty symbol table");
        return min(root).key;
    }

    private Node min(Node x) {
        if (x.left == null) return x;
        else                return min(x.left);
    }

    /**
     * Returns the largest key.
     *
     * @throws java.util.NoSuchElementException if the table is empty
     */
    public Key max() {
        if (isEmpty()) throw new java.util.NoSuchElementException("calls max() with empty symbol table");
        return max(root).key;
    }

    private Node max(Node x) {
        if (x.right == null) return x;
        else                 return max(x.right);
    }

    /**
     * Returns the largest key less than or equal to {@code key}.
     *
     * @throws IllegalArgumentException if {@code key} is null
     * @throws java.util.NoSuchElementException if the table is empty or no such key exists
     */
    public Key floor(Key key) {
        if (key == null) throw new IllegalArgumentException("argument to floor() is null");
        if (isEmpty()) throw new java.util.NoSuchElementException("calls floor() with empty symbol table");
        Node x = floor(root, key);
        if (x == null) throw new java.util.NoSuchElementException("argument to floor() is too small");
        else return x.key;
    }

    private Node floor(Node x, Key key) {
        if (x == null) return null;
        int cmp = key.compareTo(x.key);
        if (cmp == 0) return x;
        if (cmp <  0) return floor(x.left, key);
        Node t = floor(x.right, key);
        if (t != null) return t;
        else return x;
    }

    /**
     * Alternative floor that carries the best candidate down the search path.
     *
     * @throws IllegalArgumentException if {@code key} is null
     * @throws java.util.NoSuchElementException if no key is less than or equal to {@code key}
     */
    public Key floor2(Key key) {
        if (key == null) throw new IllegalArgumentException("argument to floor() is null");
        Key x = floor2(root, key, null);
        if (x == null) throw new java.util.NoSuchElementException("argument to floor() is too small");
        else return x;
    }

    private Key floor2(Node x, Key key, Key best) {
        if (x == null) return best;
        int cmp = key.compareTo(x.key);
        if      (cmp  < 0) return floor2(x.left, key, best);
        else if (cmp  > 0) return floor2(x.right, key, x.key);
        else               return x.key;
    }

    /**
     * Returns the smallest key greater than or equal to {@code key}.
     *
     * @throws IllegalArgumentException if {@code key} is null
     * @throws java.util.NoSuchElementException if the table is empty or no such key exists
     */
    public Key ceiling(Key key) {
        if (key == null) throw new IllegalArgumentException("argument to ceiling() is null");
        if (isEmpty()) throw new java.util.NoSuchElementException("calls ceiling() with empty symbol table");
        Node x = ceiling(root, key);
        if (x == null) throw new java.util.NoSuchElementException("argument to ceiling() is too large");
        else return x.key;
    }

    private Node ceiling(Node x, Key key) {
        if (x == null) return null;
        int cmp = key.compareTo(x.key);
        if (cmp == 0) return x;
        if (cmp < 0) {
            Node t = ceiling(x.left, key);
            if (t != null) return t;
            else return x;
        }
        return ceiling(x.right, key);
    }

    /**
     * Returns the key of the given 0-based rank (the {@code rank}-th smallest key).
     *
     * @throws IllegalArgumentException unless {@code 0 <= rank < size()}
     */
    public Key select(int rank) {
        if (rank < 0 || rank >= size()) {
            throw new IllegalArgumentException("argument to select() is invalid: " + rank);
        }
        return select(root, rank);
    }

    /** Returns the height of the tree in edges; an empty tree has height -1. */
    public int height() {
        return height(root);
    }

    private int height(Node x) {
        if (x == null) return -1;
        return 1 + Math.max(height(x.left), height(x.right));
    }

    /** Returns the keys in level order (breadth-first, left to right). */
    public Iterable<Key> levelOrder() {
        java.util.Queue<Key> keys = new java.util.ArrayDeque<Key>();
        // ArrayDeque rejects nulls, so only non-null nodes are ever enqueued.
        java.util.Queue<Node> queue = new java.util.ArrayDeque<Node>();
        if (root != null) queue.add(root);
        while (!queue.isEmpty()) {
            Node x = queue.remove();
            keys.add(x.key);
            if (x.left != null)  queue.add(x.left);
            if (x.right != null) queue.add(x.right);
        }
        return keys;
    }

   /*************************************************************************
    *  Check integrity of BST data structure.
    ***************************************************************************/
    private boolean check() {
        boolean ordered = isBST();
        boolean sized = isSizeConsistent();
        boolean ranked = isRankConsistent();
        if (!ordered) System.out.println("Not in symmetric order");
        if (!sized)   System.out.println("Subtree counts not consistent");
        if (!ranked)  System.out.println("Ranks not consistent");
        return ordered && sized && ranked;
    }

    // does this binary tree satisfy symmetric order?
    // Note: this test also ensures that data structure is a binary tree since order is strict
    private boolean isBST() {
        return isBST(root, null, null);
    }

    // is the tree rooted at x a BST with all keys strictly between min and max
    // (if min or max is null, treat as empty constraint)
    // Credit: Bob Dondero's elegant solution
    private boolean isBST(Node x, Key min, Key max) {
        if (x == null) return true;
        if (min != null && x.key.compareTo(min) <= 0) return false;
        if (max != null && x.key.compareTo(max) >= 0) return false;
        return isBST(x.left, min, x.key) && isBST(x.right, x.key, max);
    }

    // are the size fields correct?
    private boolean isSizeConsistent() { return isSizeConsistent(root); }
    private boolean isSizeConsistent(Node x) {
        if (x == null) return true;
        if (x.size != size(x.left) + size(x.right) + 1) return false;
        return isSizeConsistent(x.left) && isSizeConsistent(x.right);
    }

    // check that rank and select are inverses of each other
    private boolean isRankConsistent() {
        for (int i = 0; i < size(); i++)
            if (i != rank(select(i))) return false;
        for (Key key : keys())
            if (key.compareTo(select(rank(key))) != 0) return false;
        return true;
    }

    /**
     * Reads whitespace-separated keys from standard input, maps each to its
     * input index, then prints the table in level order and in sorted order.
     */
    public static void main(String[] args) {
        BST<String, Integer> st = new BST<String, Integer>();
        try (java.util.Scanner in = new java.util.Scanner(System.in)) {
            for (int i = 0; in.hasNext(); i++) {
                String key = in.next();
                st.put(key, i);
            }
        }

        for (String s : st.levelOrder())
            System.out.println(s + " " + st.get(s));

        System.out.println();

        for (String s : st.keys())
            System.out.println(s + " " + st.get(s));
    }
}



Python:



from queue import Queue

class BST:
    """Ordered symbol table of (key, value) pairs backed by an unbalanced BST.

    Python port of the Java ``BST`` above. Keys must support ``<``, ``>`` and
    ``==`` against each other; ``None`` is rejected as a key. ``None`` is not a
    storable value either: ``put(key, None)`` deletes ``key``. All helpers are
    recursive, so a very unbalanced tree (e.g. one built from sorted input) can
    exceed Python's recursion limit.
    """

    class Node:
        """One tree node: a key, its value, two children and its subtree size."""

        __slots__ = ("key", "val", "left", "right", "size")

        def __init__(self, key, val, size):
            self.key = key
            self.val = val
            self.left = None
            self.right = None
            self.size = size  # number of nodes in the subtree rooted here

    def __init__(self):
        """Initializes an empty symbol table."""
        self._root = None

    # ----------------------------------------------------------------- sizes

    def size(self, lo=None, hi=None):
        """Return the number of keys, or the number of keys in ``[lo, hi]``.

        With no arguments this is the size of the whole table (the Java
        ``size()``); with both bounds it is the Java ``size(lo, hi)``.
        Raises ValueError if only one bound is supplied.
        """
        if lo is None and hi is None:
            return self._size(self._root)
        if lo is None:
            raise ValueError("first argument to size() is null")
        if hi is None:
            raise ValueError("second argument to size() is null")
        if lo > hi:
            return 0
        if self.contains(hi):
            return self.rank(hi) - self.rank(lo) + 1
        return self.rank(hi) - self.rank(lo)

    def _size(self, x):
        """Number of nodes in the subtree rooted at ``x`` (0 for an empty subtree)."""
        return 0 if x is None else x.size

    def isEmpty(self):
        """Return True if this symbol table is empty."""
        return self.size() == 0

    # ------------------------------------------------------------ iteration

    def keys(self, lo=None, hi=None):
        """Return a list of keys in ascending order.

        With no arguments all keys are returned; with both bounds only keys in
        the closed range ``[lo, hi]``. Raises ValueError if only one bound is
        supplied.
        """
        if lo is None and hi is None:
            if self.isEmpty():
                return []
            lo, hi = self.min(), self.max()
        elif lo is None:
            raise ValueError("first argument to keys() is null")
        elif hi is None:
            raise ValueError("second argument to keys() is null")
        result = []
        self._keys(self._root, result, lo, hi)
        return result

    def _keys(self, x, out, lo, hi):
        # In-order walk appending keys in [lo, hi]; only descend into a child
        # subtree when it can still contain keys in the range.
        if x is None:
            return
        if lo < x.key:
            self._keys(x.left, out, lo, hi)
        if lo <= x.key <= hi:
            out.append(x.key)
        if hi > x.key:
            self._keys(x.right, out, lo, hi)

    def levelOrder(self):
        """Return the keys in level order (breadth-first, left to right)."""
        from collections import deque

        keys = []
        pending = deque()
        if self._root is not None:
            pending.append(self._root)
        while pending:
            x = pending.popleft()
            keys.append(x.key)
            if x.left is not None:
                pending.append(x.left)
            if x.right is not None:
                pending.append(x.right)
        return keys

    # ---------------------------------------------------------- rank/select

    def rank(self, key):
        """Return the number of keys strictly less than ``key``."""
        if key is None:
            raise ValueError("argument to rank() is null")
        return self._rank(key, self._root)

    def _rank(self, key, x):
        # Number of keys in the subtree rooted at x that are strictly less than key.
        if x is None:
            return 0
        if key < x.key:
            return self._rank(key, x.left)
        if key > x.key:
            return 1 + self._size(x.left) + self._rank(key, x.right)
        return self._size(x.left)

    def select(self, rank):
        """Return the key of 0-based ``rank``; raises ValueError if out of range."""
        if rank < 0 or rank >= self.size():
            raise ValueError("argument to select() is invalid: " + str(rank))
        return self._select(self._root, rank)

    def _select(self, x, rank):
        # Key of the given rank within the subtree rooted at x.
        # Precondition: rank is in legal range.
        if x is None:
            return None
        left_size = self._size(x.left)
        if left_size > rank:
            return self._select(x.left, rank)
        if left_size < rank:
            return self._select(x.right, rank - left_size - 1)
        return x.key

    # --------------------------------------------------------------- lookup

    def contains(self, key):
        """Return True if the table holds a value for ``key``."""
        if key is None:
            raise ValueError("argument to contains() is null")
        return self.get(key) is not None

    def get(self, key):
        """Return the value paired with ``key``, or None if absent."""
        return self._get(self._root, key)

    def _get(self, x, key):
        if key is None:
            raise ValueError("calls get() with a null key")
        if x is None:
            return None
        if key < x.key:
            return self._get(x.left, key)
        if key > x.key:
            return self._get(x.right, key)
        return x.val

    # ------------------------------------------------------------ mutation

    def put(self, key, val):
        """Insert ``(key, val)``, overwriting any old value; ``val is None`` deletes ``key``."""
        if key is None:
            raise ValueError("calls put() with a null key")
        if val is None:
            self.delete(key)
            return
        self._root = self._put(self._root, key, val)
        assert self._check()

    def _put(self, x, key, val):
        if x is None:
            return self.Node(key, val, 1)
        if key < x.key:
            x.left = self._put(x.left, key, val)
        elif key > x.key:
            x.right = self._put(x.right, key, val)
        else:
            x.val = val
        x.size = 1 + self._size(x.left) + self._size(x.right)
        return x

    def deleteMin(self):
        """Remove the smallest key; raises LookupError if the table is empty."""
        if self.isEmpty():
            raise LookupError("Symbol table underflow")
        self._root = self._deleteMin(self._root)
        assert self._check()

    def _deleteMin(self, x):
        if x.left is None:
            return x.right
        x.left = self._deleteMin(x.left)
        x.size = self._size(x.left) + self._size(x.right) + 1
        return x

    def deleteMax(self):
        """Remove the largest key; raises LookupError if the table is empty."""
        if self.isEmpty():
            raise LookupError("Symbol table underflow")
        self._root = self._deleteMax(self._root)
        assert self._check()

    def _deleteMax(self, x):
        if x.right is None:
            return x.left
        x.right = self._deleteMax(x.right)
        x.size = self._size(x.left) + self._size(x.right) + 1
        return x

    def delete(self, key):
        """Remove ``key`` and its value if present; otherwise do nothing."""
        if key is None:
            raise ValueError("calls delete() with a null key")
        self._root = self._delete(self._root, key)
        assert self._check()

    def _delete(self, x, key):
        # Hibbard deletion: a node with two children is replaced by its successor.
        if x is None:
            return None
        if key < x.key:
            x.left = self._delete(x.left, key)
        elif key > x.key:
            x.right = self._delete(x.right, key)
        else:
            if x.right is None:
                return x.left
            if x.left is None:
                return x.right
            t = x
            x = self._min(t.right)
            x.right = self._deleteMin(t.right)
            x.left = t.left
        x.size = self._size(x.left) + self._size(x.right) + 1
        return x

    # ---------------------------------------------------- ordered queries

    def min(self):
        """Return the smallest key; raises LookupError if the table is empty."""
        if self.isEmpty():
            raise LookupError("calls min() with empty symbol table")
        return self._min(self._root).key

    def _min(self, x):
        while x.left is not None:
            x = x.left
        return x

    def max(self):
        """Return the largest key; raises LookupError if the table is empty."""
        if self.isEmpty():
            raise LookupError("calls max() with empty symbol table")
        return self._max(self._root).key

    def _max(self, x):
        while x.right is not None:
            x = x.right
        return x

    def floor(self, key):
        """Return the largest key <= ``key``; raises LookupError if none exists."""
        if key is None:
            raise ValueError("argument to floor() is null")
        if self.isEmpty():
            raise LookupError("calls floor() with empty symbol table")
        x = self._floor(self._root, key)
        if x is None:
            raise LookupError("argument to floor() is too small")
        return x.key

    def _floor(self, x, key):
        if x is None:
            return None
        if key == x.key:
            return x
        if key < x.key:
            return self._floor(x.left, key)
        t = self._floor(x.right, key)
        return t if t is not None else x

    def floor2(self, key):
        """Alternative floor that carries the best candidate down the search path."""
        if key is None:
            raise ValueError("argument to floor() is null")
        x = self._floor2(self._root, key, None)
        if x is None:
            raise LookupError("argument to floor() is too small")
        return x

    def _floor2(self, x, key, best):
        if x is None:
            return best
        if key < x.key:
            return self._floor2(x.left, key, best)
        if key > x.key:
            return self._floor2(x.right, key, x.key)
        return x.key

    def ceiling(self, key):
        """Return the smallest key >= ``key``; raises LookupError if none exists."""
        if key is None:
            raise ValueError("argument to ceiling() is null")
        if self.isEmpty():
            raise LookupError("calls ceiling() with empty symbol table")
        x = self._ceiling(self._root, key)
        if x is None:
            raise LookupError("argument to ceiling() is too large")
        return x.key

    def _ceiling(self, x, key):
        if x is None:
            return None
        if key == x.key:
            return x
        if key < x.key:
            t = self._ceiling(x.left, key)
            return t if t is not None else x
        return self._ceiling(x.right, key)

    def height(self):
        """Return the tree height in edges; an empty tree has height -1."""
        return self._height(self._root)

    def _height(self, x):
        if x is None:
            return -1
        # ``max`` here is the builtin: class attributes are not in a method's scope.
        return 1 + max(self._height(x.left), self._height(x.right))

    # ------------------------------------------------------ integrity checks

    def _check(self):
        """Print any violated invariant and return True only if all hold."""
        ordered = self.isBST()
        sized = self.isSizeConsistent()
        ranked = self._isRankConsistent()
        if not ordered:
            print("Not in symmetric order")
        if not sized:
            print("Subtree counts not consistent")
        if not ranked:
            print("Ranks not consistent")
        return ordered and sized and ranked

    # does this binary tree satisfy symmetric order?
    # Note: this test also ensures that data structure is a binary tree since
    # order is strict.
    def isBST(self):
        return self._isBST(self._root, None, None)

    # is the tree rooted at x a BST with all keys strictly between lo and hi
    # (if lo or hi is None, treat as empty constraint)
    # Credit: Bob Dondero's elegant solution
    def _isBST(self, x, lo, hi):
        if x is None:
            return True
        if lo is not None and x.key <= lo:
            return False
        if hi is not None and x.key >= hi:
            return False
        return self._isBST(x.left, lo, x.key) and self._isBST(x.right, x.key, hi)

    # are the size fields correct?
    def isSizeConsistent(self):
        return self._isSizeConsistent(self._root)

    def _isSizeConsistent(self, x):
        if x is None:
            return True
        if x.size != self._size(x.left) + self._size(x.right) + 1:
            return False
        return self._isSizeConsistent(x.left) and self._isSizeConsistent(x.right)

    # check that rank and select are inverses of each other
    def _isRankConsistent(self):
        for i in range(self.size()):
            if i != self.rank(self.select(i)):
                return False
        for key in self.keys():
            if key != self.select(self.rank(key)):
                return False
        return True


def main():
    """Read whitespace-separated keys from stdin and print the resulting table.

    Each key is mapped to its 0-based position in the input (a repeated key
    keeps its last position). The keys are printed in level order, then a
    blank line, then in sorted order, each followed by its value.
    """
    import sys

    st = BST()
    for i, key in enumerate(sys.stdin.read().split()):
        st.put(key, i)
    for s in st.levelOrder():
        print(s, st.get(s))  # print() stringifies the int value; ``s + " " + int`` would raise
    print()
    for s in st.keys():
        print(s, st.get(s))



#2. Binary Search Tree Implementation storing only values and not keys:


In this implementation we store only values (one value per node), and the Binary Search Tree is ordered by those values.

Java:


// This BST implementation accepts duplicate values:
// leftNode.value <= currentNode.value < rightNode.value
// value <= root->value goes to the left subtree (equal values go left)
// value > root->value goes to the right subtree

// But if the BST stores (key, value) pairs, it should NOT accept duplicate keys;
// instead, putting an entry whose key is already present should replace the
// previous value for that key.

/**
 * Mutable node of the value-only BST: one int value, two child links, and the
 * cached size of the subtree rooted here (this node included).
 */
public class Node {
    public int value;
    public int size = 1;       // nodes in this subtree, counting this node itself
    public Node left = null;   // subtree of values <= value
    public Node right = null;  // subtree of values > value

    /** Creates a leaf holding {@code value}; a leaf's subtree size is 1. */
    public Node(int value) {
        this.value = value;
    }
}
    

/**
 * Binary search tree of int values that accepts duplicates.
 *
 * <p>Invariant: {@code left.value <= node.value < right.value}; equal values go
 * to the left. Every {@link Node} caches the size of its subtree, and each
 * mutating method keeps those sizes up to date, so {@code size}, {@code rank}
 * and {@code select} work from the cached counts. {@code rank} and
 * {@code select} are 1-based.
 */
public class BST {

    private Node root = null;  // root of BST
    // No separate size counter is needed: each Node caches its subtree size,
    // so the tree size is root.size.

    /** Returns the number of values stored in the tree (duplicates included). */
    public int size() {
        return size(root);
    }

    // Size of the subtree rooted at node from the cached Node.size field
    // (kept current by every mutating method below); an empty subtree is 0.
    private int size(Node node) {
        if (node == null) return 0;
        return node.size;
    }

    /** Returns true if the tree holds no values. */
    public boolean isEmpty() {
        return size() == 0;
    }

    /** Inserts {@code value}; a duplicate is placed in the left subtree. */
    public void insert(int value) {
        if (root == null) {
            root = new Node(value);
            return;
        }
        insert(root, value);
    }

    // Inserts below the non-null node, incrementing size on every node of the path.
    private void insert(Node node, int value) {
        if (value > node.value) {
            if (node.right != null) insert(node.right, value);
            else                    node.right = new Node(value);
        } else {  // equal values go to the left
            if (node.left != null)  insert(node.left, value);
            else                    node.left = new Node(value);
        }
        node.size++;
    }

    /** Removes one occurrence of {@code value}; does nothing if it is absent. */
    public void delete(int value) {
        // delete(root, value) returns the (possibly new) root, which also covers
        // the case where the root node itself is deleted.
        root = delete(root, value);
    }

    // Deletes one node holding value from the subtree rooted at node and
    // returns the subtree's new root (unchanged if value is absent).
    private Node delete(Node node, int value) {
        if (node == null) return null;  // value not present

        int compare = Integer.compare(value, node.value);
        if (compare > 0) {
            node.right = delete(node.right, value);
        } else if (compare < 0) {
            node.left = delete(node.left, value);
        } else {
            if (node.right == null) return node.left;
            if (node.left == null) return node.right;
            // Two children: replace the node with its in-order predecessor (max of
            // the left subtree). The successor (min of the right subtree) would
            // break "right > node" when an equal-valued node sits above it in the
            // right subtree; the predecessor keeps left <= node < right intact.
            Node removed = node;
            node = max(removed.left);
            node.left = deleteMax(removed.left);
            node.right = removed.right;
        }

        // Recompute from the children instead of decrementing: the size only
        // shrinks if a node was actually removed somewhere below.
        node.size = size(node.left) + 1 + size(node.right);
        return node;
    }

    // Node with the largest value in the non-empty subtree rooted at node.
    private Node max(Node node) {
        while (node.right != null) node = node.right;
        return node;
    }

    // Removes the largest-valued node from the non-empty subtree rooted at node
    // and returns the new subtree root, keeping size fields consistent.
    private Node deleteMax(Node node) {
        if (node.right == null) return node.left;
        node.right = deleteMax(node.right);
        node.size = size(node.left) + 1 + size(node.right);
        return node;
    }

    /**
     * Removes the smallest value.
     *
     * @throws java.util.NoSuchElementException if the tree is empty
     */
    public void deleteMin() {
        if (isEmpty()) throw new java.util.NoSuchElementException("Tree is Empty");
        root = deleteMin(root);
    }

    /**
     * Iteratively removes the smallest value from the subtree rooted at
     * {@code node} and returns that subtree's (possibly new) root.
     */
    public Node deleteMin(Node node) {
        if (node == null) return null;             // empty subtree: nothing to delete
        if (node.left == null) return node.right;  // node itself is the minimum
        Node subtreeRoot = node;
        while (node.left.left != null) {           // stop at the parent of the minimum
            node.size--;
            node = node.left;
        }
        node.size--;
        node.left = node.left.right;  // splice out the minimum, keeping its right subtree
        return subtreeRoot;
    }

    /**
     * Returns the 1-based rank of {@code value}: the number of stored values
     * less than or equal to it.
     *
     * @throws java.util.NoSuchElementException if {@code value} is not in the tree
     */
    public int rank(int value) {
        return rank(root, value);
    }

    private int rank(Node node, int value) {
        if (node == null) throw new java.util.NoSuchElementException("Value not present in the Tree");
        int compare = Integer.compare(value, node.value);
        if (compare < 0) return rank(node.left, value);
        else if (compare > 0) return size(node.left) + 1 + rank(node.right, value);
        else return size(node.left) + 1;
    }

    /**
     * Returns the {@code index}-th smallest value (1-based).
     *
     * @throws IllegalArgumentException unless {@code 1 <= index <= size()}
     */
    public int select(int index) {
        if (index < 1 || index > size()) throw new IllegalArgumentException("Invalid index");
        return select(root, index);
    }

    private int select(Node node, int index) {
        int rank = size(node.left) + 1;  // 1-based position of node within its subtree
        if (index < rank) return select(node.left, index);
        else if (index > rank) return select(node.right, index - rank);
        else return node.value;
    }

    /** Returns the node holding the smallest value, or null if the tree is empty. */
    public Node min() {
        return min(root);
    }

    private Node min(Node node) {
        if (node == null) return null;
        if (node.left == null) return node;
        else return min(node.left);
    }

    /** Returns true if {@code value} is stored in the tree. */
    public boolean contains(int value) {
        return contains(root, value);
    }

    private boolean contains(Node node, int value) {
        if (node == null) return false;
        if (node.value > value) return contains(node.left, value);
        else if (node.value < value) return contains(node.right, value);
        else return true;
    }

    /** Returns the height in nodes: 0 for an empty tree, 1 for a single node. */
    public int height() {
        return height(root);
    }

    private int height(Node node) {
        if (node == null) return 0;
        return Math.max(height(node.left), height(node.right)) + 1;
    }

    /**
     * Recursive version of {@link #deleteMin(Node)}: removes the smallest value
     * from the non-empty subtree rooted at {@code node} and returns its new root.
     */
    public Node deleteMinRECURSIVE(Node node) {
        if (node.left == null) return node.right;  // node itself is the minimum
        node.left = deleteMinRECURSIVE(node.left);
        node.size = size(node.left) + 1 + size(node.right);
        return node;
    }
}



Python:


from Node import Node


class BST:
    """Binary search tree that stores bare values, with order statistics.

    A value less than or equal to a node's value goes into that node's left
    subtree, so duplicates are kept. Every node carries ``size``, the number
    of nodes in its subtree, which ``size``, ``rank`` and ``select`` rely on.

    Nodes are instances of the ``Node`` class imported at module level; only
    ``Node(value)`` and the ``value``/``left``/``right``/``size`` attributes
    are used here. Stored values must support ``<`` and ``>`` with each other.
    """

    def __init__(self):
        # Only the root is stored; everything else (minimum, size, ...) is
        # derived from the nodes on demand.
        self._root = None

    @staticmethod
    def _new_node(value):
        """Return a new leaf holding ``value`` with ``size`` set to 1."""
        node = Node(value)
        # Set explicitly so the size bookkeeping does not depend on the
        # default chosen by Node's constructor.
        node.size = 1
        return node

    def size(self):
        """Return the number of values stored in the tree."""
        return self._size(self._root)

    def _size(self, node):
        """Return the cached subtree size of ``node``; 0 for an empty subtree."""
        return 0 if node is None else node.size

    def isEmpty(self):
        """Return True if the tree holds no values."""
        return self.size() == 0

    def insert(self, value):
        """Insert ``value``; a value equal to a node's value goes to its left."""
        if self._root is None:
            self._root = self._new_node(value)
            return
        self._insert(self._root, value)

    def _insert(self, node, value):
        """Insert ``value`` below the non-empty subtree rooted at ``node``.

        Iterative, so sorted input (which builds a list-shaped tree) cannot
        exceed Python's recursion limit here. Every node on the search path
        gains one descendant, so its ``size`` is incremented on the way down.
        """
        while True:
            node.size += 1
            if value > node.value:
                if node.right is None:
                    node.right = self._new_node(value)
                    return
                node = node.right
            else:
                if node.left is None:
                    node.left = self._new_node(value)
                    return
                node = node.left

    def delete(self, value):
        """Remove one occurrence of ``value``; do nothing if it is absent."""
        # _delete returns the (possibly new) root, which covers deleting the root.
        self._root = self._delete(self._root, value)

    def _delete(self, node, value):
        """Delete ``value`` from the subtree at ``node`` and return its new root.

        A node with two children is replaced by its successor (the minimum of
        its right subtree), i.e. Hibbard deletion.
        """
        if node is None:
            return None
        if value > node.value:
            node.right = self._delete(node.right, value)
        elif value < node.value:
            node.left = self._delete(node.left, value)
        else:
            if node.right is None:
                return node.left
            if node.left is None:
                return node.right
            doomed = node
            node = self._min(doomed.right)
            node.right = self._deleteMin(doomed.right)
            node.left = doomed.left
        # Recompute rather than decrement: the value may not have been found.
        node.size = self._size(node.left) + 1 + self._size(node.right)
        return node

    def deleteMin(self):
        """Remove the smallest value.

        Raises:
            Exception: if the tree is empty.
        """
        if self.isEmpty():
            raise Exception("Tree is Empty")
        self._root = self._deleteMin(self._root)

    def _deleteMin(self, node):
        """Remove the minimum of the non-empty subtree at ``node``; return its new root."""
        if node.left is None:
            return node.right  # the subtree root itself is the minimum
        root = node
        # Walk down to the parent of the minimum, shrinking each ancestor.
        while node.left.left is not None:
            node.size -= 1
            node = node.left
        node.size -= 1
        node.left = node.left.right  # splice out the minimum
        return root

    def rank(self, value):
        """Return the 1-based rank of ``value``: how many stored values are <= it.

        Raises:
            Exception: if ``value`` is not in the tree.
        """
        return self._rank(self._root, value)

    def _rank(self, node, value):
        if node is None:
            raise Exception("Value not present in the Tree")
        if value < node.value:
            return self._rank(node.left, value)
        if value > node.value:
            return self._size(node.left) + 1 + self._rank(node.right, value)
        return self._size(node.left) + 1

    def select(self, index):
        """Return the value at 1-based position ``index`` in sorted order.

        Raises:
            Exception: if ``index`` is outside ``[1, size()]``.
        """
        if index < 1 or index > self._size(self._root):
            raise Exception("Invalid index")
        return self._select(self._root, index)

    def _select(self, node, index):
        rank = self._size(node.left) + 1
        if index < rank:
            return self._select(node.left, index)
        if index > rank:
            return self._select(node.right, index - rank)
        return node.value

    def min(self):
        """Return the node holding the smallest value, or None if the tree is empty."""
        return self._min(self._root)

    def _min(self, node):
        if node is None:
            return None
        while node.left is not None:
            node = node.left
        return node

    def contains(self, value):
        """Return True if ``value`` is stored in the tree."""
        return self._contains(self._root, value)

    def _contains(self, node, value):
        if node is None:
            return False
        if node.value > value:
            return self._contains(node.left, value)
        if node.value < value:
            return self._contains(node.right, value)
        return True

    def height(self):
        """Return the number of nodes on the longest root-to-leaf path (0 if empty)."""
        return self._height(self._root)

    def _height(self, node):
        if node is None:
            return 0
        return max(self._height(node.left), self._height(node.right)) + 1

    # Recursive implementation of deleteMin
    def deleteMinRECURSIVE(self, node):
        """Remove the minimum of the subtree at ``node`` and return its new root."""
        if node is None:
            return None
        if node.left is None:
            return node.right  # short circuit when the subtree root is the minimum
        node.left = self.deleteMinRECURSIVE(node.left)
        # Either child may now be None, so use the None-safe size helper.
        node.size = self._size(node.left) + 1 + self._size(node.right)
        return node



To show how being comfortable implementing the basic BST operations helps with real-world problems, let's look at the following problem:

Imagine you're reading a stream of integers. Periodically, you want to look up the rank of a number x, i.e. the number of values seen so far that are less than or equal to x, not counting x itself. Implement a method track(int x), which is called each time a number is generated by the stream, and a method getRankOf(int x), which returns the number of values less than or equal to x, not including x itself.
Integer Stream: 7, 8, 3, 8, 2, 5, 5, 9
getRankOf(2) = 0
getRankOf(5) = 3


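The core of the solution is simply the rank(value) operation we implemented above for the basic BST. There are two differences. First, the rank here does not count x itself, so it is one less than the 1-based rank. Second, each node only needs to store the size of its left subtree, because that is all a rank query reads.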

Java:


/**
 * Node of a binary search tree used to answer rank queries over an integer stream.
 *
 * <p>Values less than or equal to a node's value are stored in its left subtree.
 * {@code sizeOfLeftSubtree} caches the number of nodes in the left subtree, so
 * {@link #getRank(int)} never has to walk a left subtree it skips over.
 */
public class BSTNode {
    public BSTNode left;
    public BSTNode right;
    public int sizeOfLeftSubtree;
    public int val;

    public BSTNode(int val) {
        this.val = val;
    }

    /**
     * Creates a node with the given children. {@code sizeOfLeftSubtree} is computed from
     * {@code left} so that ranks stay correct when a prebuilt subtree is attached.
     */
    public BSTNode(int val, BSTNode left, BSTNode right) {
        this.val = val;
        this.left = left;
        this.right = right;
        this.sizeOfLeftSubtree = countNodes(left);
    }

    /** Returns the number of nodes in the subtree rooted at {@code node} (0 for null). */
    private static int countNodes(BSTNode node) {
        int count = 0;
        // Each left subtree is already summarized by sizeOfLeftSubtree; only the right spine is walked.
        for (BSTNode n = node; n != null; n = n.right) {
            count += n.sizeOfLeftSubtree + 1;
        }
        return count;
    }

    /**
     * Inserts {@code num} into the subtree rooted at this node.
     *
     * <p>Iterative, so a long sorted stream (which builds a list-shaped tree) cannot overflow
     * the call stack.
     */
    public void insert(int num) {
        BSTNode node = this;
        while (true) {
            if (num <= node.val) {
                node.sizeOfLeftSubtree++;
                if (node.left == null) {
                    node.left = new BSTNode(num);
                    return;
                }
                node = node.left;
            } else {
                if (node.right == null) {
                    node.right = new BSTNode(num);
                    return;
                }
                node = node.right;
            }
        }
    }

    /**
     * Returns how many values in this subtree are less than or equal to {@code num}, not counting
     * the occurrence of {@code num} that the search stops at.
     *
     * @return that count, or -1 if {@code num} is not present in this subtree
     */
    public int getRank(int num) {
        int rank = 0;
        BSTNode node = this;
        while (node != null) {
            if (num == node.val) {
                return rank + node.sizeOfLeftSubtree;
            }
            if (num < node.val) {
                node = node.left;
            } else {
                // The whole left subtree and the node itself are all <= num.
                rank += node.sizeOfLeftSubtree + 1;
                node = node.right;
            }
        }
        return -1;
    }
}

/** Tracks the values of an integer stream and answers rank queries over the values seen so far. */
public class IntegerStream {
    BSTNode root;

    /** Records {@code num} as the next value of the stream. */
    public void track(int num) {
        if (root == null) {
            root = new BSTNode(num);
        } else {
            root.insert(num);
        }
    }

    /**
     * Returns how many tracked values are less than or equal to {@code num}, not counting
     * {@code num} itself, as computed by {@link BSTNode#getRank(int)}.
     *
     * @return that rank, or -1 when nothing has been tracked yet (the same "not found"
     *     sentinel that {@code BSTNode.getRank} uses)
     */
    public int getRankOf(int num) {
        return (root == null) ? -1 : root.getRank(num);
    }
}




Python:


from Node import Node


class BSTNode:
    """Node of a BST used to answer rank queries over an integer stream.

    Values less than or equal to a node's value go into its left subtree.
    ``sizeOfLeftSubtree`` counts the nodes in the left subtree, so a rank
    query never has to walk a left subtree it skips over.
    """

    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right
        # Count a caller-supplied left subtree; leaving this at 0 would make
        # every rank computed through this node too small.
        self.sizeOfLeftSubtree = BSTNode._count(left)

    @staticmethod
    def _count(node):
        """Return the number of nodes in the subtree rooted at ``node`` (0 for None)."""
        total = 0
        # Each left subtree is already summarized by sizeOfLeftSubtree, so
        # only the right spine has to be walked.
        while node is not None:
            total += node.sizeOfLeftSubtree + 1
            node = node.right
        return total

    def insert(self, num):
        """Insert ``num`` below this node, keeping sizeOfLeftSubtree up to date.

        Iterative, so a sorted stream (which builds a list-shaped tree) cannot
        exceed Python's recursion limit.
        """
        node = self
        while True:
            if num <= node.val:
                node.sizeOfLeftSubtree += 1
                if node.left is None:
                    node.left = BSTNode(num)
                    return
                node = node.left
            else:
                if node.right is None:
                    node.right = BSTNode(num)
                    return
                node = node.right

    def getRank(self, num):
        """Return how many values <= ``num`` are in this subtree, excluding ``num`` itself.

        Returns -1 if ``num`` is not present in the subtree.
        """
        rank = 0
        node = self
        while node is not None:
            if num == node.val:
                return rank + node.sizeOfLeftSubtree
            if num < node.val:
                node = node.left
            else:
                # The whole left subtree and the node itself are all <= num.
                rank += node.sizeOfLeftSubtree + 1
                node = node.right
        return -1



Time Complexity:

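Track method: same as a BST insert, O(log n) on average, where n is the number of integers received so far from the stream. The tree is not self-balancing, so the worst case (for example, a sorted stream) is O(n).
getRankOf method: same as the BST rank operation, O(log n) on average and O(n) in the worst case.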


Instructor:





Help Your Friends save 25% on our products

wave