solution_060b.py

#!/usr/bin/python3
# ====================================================================
# binary tree - build and balance trees
# ====================================================================

import user_interface as ui

# --------------------------------------------------------------------
# ---- tree node class
# ----
# ---- Traditionally the child nodes of a parent node are named
# ---- left and right. Left is the sub-tree of all of the
# ---- nodes less than the parent node. Right is a subtree of
# ---- all of the nodes greater (or equal) to the parent node.
# --------------------------------------------------------------------

class Node:
    """A single binary tree node.

    Attributes set by the constructor:
        value  -- the value stored in the node (as passed in)
        parent -- parent node, None until the node is linked into a tree
        left   -- left child, initially None
        right  -- right child, initially None
        depth  -- depth of the node from the tree root, initially 0
    """

    def __init__(self, value):
        self.value = value
        # ---- a new node is not linked to anything yet
        self.parent = self.left = self.right = None
        self.depth = 0


# --------------------------------------------------------------------
# ---- display node
# --------------------------------------------------------------------

def display_node(node):
    """Print a node's value, its parent/left/right values, and its depth.

    node -- object with value, parent, left, right and depth attributes;
            a missing parent or child is shown as 'None'.

    Every label is padded to the same width so the values line up in
    one column.
    """

    def value_of(other):
        # ---- a missing link is displayed as the text 'None'
        return 'None' if other is None else other.value

    print(f'------ node {node.value:<3}')
    print(f'parent {value_of(node.parent)}')
    print(f'left   {value_of(node.left)}')
    print(f'right  {value_of(node.right)}')
    print(f'depth  {node.depth}')
    print('---------------')

# --------------------------------------------------------------------
# ---- calculate the maximum height of the tree
# ---- (call function recursively)
# --------------------------------------------------------------------

def height(node):
    """Return the number of nodes on the longest path from node down
    to a leaf. An empty (None) subtree has height 0."""

    if node is None:
        return 0

    # ---- this node plus the taller of its two subtrees
    return 1 + max(height(node.left), height(node.right))


# --------------------------------------------------------------------
# ---- calculate a node's depth from the root node
# --------------------------------------------------------------------

def node_depth(node):
    """Return the number of parent links between node and the root
    (the node whose parent is None). The root itself has depth 0."""

    levels = 0
    ancestor = node.parent
    while ancestor is not None:
        levels += 1
        ancestor = ancestor.parent
    return levels

# --------------------------------------------------------------------
# ---- traverse a tree in order - return array(s)
# ---- (call function recursively)
# --------------------------------------------------------------------

def traverse_tree_in_order(node, nodes, values):
    """Walk the subtree rooted at node in order (left, node, right).

    nodes  -- list that receives each visited node, or None to skip
    values -- list that receives each visited node's value, or None
              to skip

    The lists are appended to in place; nothing is returned.
    """

    if node is not None:

        # ---- everything to the left comes first
        traverse_tree_in_order(node.left, nodes, values)

        # ---- then this node
        if nodes is not None:
            nodes.append(node)
        if values is not None:
            values.append(node.value)

        # ---- then everything to the right
        traverse_tree_in_order(node.right, nodes, values)


# --------------------------------------------------------------------
# ---- display tree node values in order
# --------------------------------------------------------------------

def display_node_values(root):
    """Print the values of the tree rooted at root, in order, on one
    line (each value followed by a space), or a message when the
    traversal yields no values."""

    # ---- collect the values with an in-order traversal

    ordered = []
    traverse_tree_in_order(root, None, ordered)

    print()

    if not ordered:
        print('value list is empty')
        return

    # ---- one line of values, then a blank line

    print(''.join(f'{v} ' for v in ordered), end='')
    print('\n')

# --------------------------------------------------------------------
# ---- balance a binary tree
# ---- build a tree from a list of nodes ordered by value
# ---- (set up and then call a recursive function)
# ---- based on:
# ---- www.geeksforgeeks.org/convert-normal-bst-balanced-bst/
# --------------------------------------------------------------------

def construct_balanced_tree(root):
    """Build and return a new tree holding the values of the tree
    rooted at root, inserted in an order intended to balance it.

    The existing nodes are listed in order; the middle value becomes
    the new root and the middle values of the left and right halves
    are then inserted recursively with add_node(). Only the values
    are copied - the new tree is made of new Node objects and the
    original tree is not modified.

    Returns the root of the new tree, or None when root is None
    (there is nothing to balance).
    """

    # ---- an empty tree has no balanced counterpart

    if root is None:
        return None

    def balanced_tree_util(nodes, start, end, parent):
        # ---- insert nodes[start..end] (inclusive) below parent,
        # ---- middle element first

        if start > end:
            return

        mid = (start + end) // 2
        add_node(Node(nodes[mid].value), parent)

        # ---- construct left and right subtrees

        balanced_tree_util(nodes, start, mid - 1, parent)
        balanced_tree_util(nodes, mid + 1, end, parent)

    # ---- get a list of existing tree nodes in order

    nodes = []
    traverse_tree_in_order(root, nodes, None)

    # ---- the middle node's value becomes the new root

    last = len(nodes) - 1
    mid = last // 2
    new_root = Node(nodes[mid].value)

    # ---- create the new balanced tree (left half, then right half)

    balanced_tree_util(nodes, 0, mid - 1, new_root)
    balanced_tree_util(nodes, mid + 1, last, new_root)

    # ---- return the balanced tree

    return new_root


# --------------------------------------------------------------------
# ---- display a tree (horizontally)
# ---- (call support function recursively)
# --------------------------------------------------------------------

def display_tree(max_width, tree, indent=0):
    """Print the tree sideways, one node per line.

    max_width -- lines longer than this many characters are not printed
    tree      -- root node of the (sub)tree to display; None prints
                 nothing
    indent    -- number of spaces placed in front of every line

    Nodes are visited right subtree, node, left subtree, so the
    right-most node is printed first. Each line is the indent, four
    dashes per level of node.depth, a space and the node's value.
    """

    def print_line(node):
        line = ' ' * indent + '-' * (node.depth * 4) + ' ' + str(node.value)
        # ---- a line that does not fit is skipped, not truncated
        if len(line) <= max_width:
            print(line)

    def traverse(node):
        if node is None:
            return
        traverse(node.right)
        print_line(node)
        traverse(node.left)

    traverse(tree)
    
    
# --------------------------------------------------------------------
# ---- add a node to a tree below the parent
# ---- (walk down the tree in a loop until a free slot is found)
# --------------------------------------------------------------------

def add_node(node, parent):
    """Insert node into the subtree rooted at parent.

    A node whose value is less than the current node's value goes
    left; otherwise (greater or equal) it goes right. The descent
    stops at the first missing child, where node is attached. The
    node's parent link is set and its depth is computed with
    node_depth().

    The descent is a loop rather than recursion so that inserting
    into a long chain of nodes cannot hit the recursion limit.
    """

    while True:
        if node.value < parent.value:
            if parent.left is None:
                parent.left = node
                break
            parent = parent.left
        else:
            if parent.right is None:
                parent.right = node
                break
            parent = parent.right

    # ---- parent is now the node the new node was attached to

    node.parent = parent
    node.depth = node_depth(node)


# --------------------------------------------------------------------
# ---- tree statistics - count nodes
# ---- Note: a list is used to hold the counts because it is mutable
# ----       integers, etc. are not
# ---- (setup then call recursive function)
# --------------------------------------------------------------------

def tree_statistics(root):
    """Count the nodes and the empty (None) links in a tree.

    Returns the tuple (node_count, none_count). Every None reached
    during the walk is counted, including root itself when it is None.
    """

    NODES, NONES = 0, 1              # positions in the tally list
    tally = [0, 0]

    def visit(current):
        if current is not None:
            tally[NODES] += 1
            visit(current.left)
            visit(current.right)
        else:
            tally[NONES] += 1

    visit(root)
    return tuple(tally)
        
# --------------------------------------------------------------------
# ---- display tree nodes
# ---- (setup then call recursive function)
# --------------------------------------------------------------------

def list_tree(root):
    """Display every node of the tree with display_node(): a node
    first, then its left subtree, then its right subtree. Prints
    'tree is empty' when root is None."""

    def list_nodes(node):
        display_node(node)
        for child in (node.left, node.right):
            if child is not None:
                list_nodes(child)

    print()

    if root is None:
        print('tree is empty')
        return

    list_nodes(root)

# --------------------------------------------------------------------
# ---- create a new node - ask the user for its value
# --------------------------------------------------------------------

def new_node():
    """Prompt the user for a node value and return a new Node.

    The prompt is repeated until the input is accepted as an integer
    from 0 to 999. Returns None if the user enters nothing.
    """

    while True:

        # ---- ask for a value; empty input means "give up"

        print()
        s = ui.get_user_input('Enter a tree node value [0 to 999]: ')
        if not s:
            return None

        # ---- accept only integers in range

        tf, v = ui.is_integer(s)
        if tf and 0 <= v <= 999:
            return Node(v)

        print()
        print(f'bad node value input ({s})')
        
# --------------------------------------------------------------------
# ---- construct a tree from values in a CSV string
# --------------------------------------------------------------------

def construct_tree_from_csv_string():
    """Prompt for a list of integers and build a tree from them.

    The values may be separated by commas and/or whitespace. The
    first value becomes the root; the rest are inserted in the order
    given with add_node(). Each value must be an integer from -999
    to 999.

    Returns the root of the new tree, or None when no values were
    entered (empty input, or input holding only separators) or when
    any value is rejected.
    """

    # ---- get csv string

    print()
    csv_str = ui.get_user_input('Enter CSV string: ')

    # ---- break string into a list of tokens
    # ---- (input such as "," yields no tokens at all)

    tokens = csv_str.replace(',', ' ').split() if csv_str else []

    if not tokens:
        print()
        print('tree is unchanged')
        return None

    # ---- convert and validate every token before building anything

    values = []

    for s in tokens:

        tf, v = ui.is_integer(s)

        if not tf or v < -999 or v > 999:
            print()
            print(f'illegal node value is CSV string ({s})')
            print('exit function - nothing modified or created')
            return None

        values.append(v)

    # ---- first value is the root, the rest are added below it

    root = Node(values[0])

    for v in values[1:]:
        add_node(Node(v), root)

    return root

# --------------------------------------------------------------------
# ---- main
# --------------------------------------------------------------------

if __name__ == '__main__':

    # ---- interactive menu loop: build, inspect and balance a tree.
    # ---- The loop ends on option 99 or on empty input.

    menu = '''
 option  description
 ------  -------------------------------- 
    1    add node to tree
    2    list tree nodes
    3    tree statistics
    4    new tree (create only root node)
    5    create tree from CSV string

   10    construct balance tree

   20    display node values in order
   
   30    display tree

   99    exit
'''

    # ---- the tree always has a root; it starts with the value 0

    root = Node(0)

    while True:

        # --- get and verify user option selection

        print(menu)
        s = ui.get_user_input(' select an option: ')
        if not s: break

        # NOTE(review): the other functions in this file call
        # ui.is_integer(); confirm that ui.is_int() also exists
        tf,opt = ui.is_int(s)
        if not tf:
            print(f'illegal option ({s}) - try again')
            continue

        # ---- add node to tree
        # ---- (new_node returns None when the user enters nothing)
        if opt == 1:            
            node = new_node()
            if node is not None: add_node(node,root)
            continue

        # ---- list tree nodes
        if opt == 2:
            list_tree(root)
            continue
        
        # ---- display tree statistics
        if opt == 3:
            stats = tree_statistics(root)
            print()
            print(f'node count  = {stats[0]}')
            print(f'none count  = {stats[1]}')
            print(f'tree height = {height(root)}')
            continue

        # ---- delete current tree
        # ---- initialize a new tree root node (value=0)
        if opt == 4:
            root = Node(0)
            continue

        # ---- construct a tree from CSV string values
        # ---- (keep the current tree if nothing was built)
        if opt == 5:
            x = construct_tree_from_csv_string()
            if x is not None: root = x
            continue

        # ---- construct a balanced tree from current tree
        # ---- (keep the current tree if no new tree was returned)
        if opt == 10:
            print()
            print(f'max height before balancing = {height(root)}') 
            node = construct_balanced_tree(root)
            if node is not None: root = node
            print(f'max height after  balancing = {height(root)}') 
            continue

        # ---- display tree node values in order
        if opt == 20:
            display_node_values(root)
            continue
            
        # ---- display tree (horizontally, at most 79 columns wide)
        if opt == 30:
            display_tree(79,root)
            continue

        # ---- exit program
        if opt == 99:
            break

        print(f'illegal option ({opt}) - try again')