212 lines
5.0 KiB
Python
212 lines
5.0 KiB
Python
'''Tree traversals'''
|
|
|
|
from itertools import chain
|
|
from functools import reduce
|
|
from operator import mul
|
|
|
|
|
|
# foldTree :: (a -> [b] -> b) -> Tree a -> b
def foldTree(f):
    '''The catamorphism on trees: a summary value obtained
    by a depth-first fold.

    f is curried: f(rootValue)(childSummaries) returns the
    summary of a node, given the already-folded summaries
    of its children, in child order.
    '''
    def summarise(node):
        # Apply f to this node's value before descending, as the
        # original did, then feed it the children's summaries.
        combine = f(root(node))
        return combine([summarise(child) for child in nest(node)])
    return summarise
|
|
|
|
|
|
# levels :: Tree a -> [[a]]
def levels(tree):
    '''A list of lists, grouping the root values of each
    level of the tree: top level first, each row ordered
    left to right.

    No empty rows are produced; a single-node tree yields
    [[rootValue]].
    '''
    def go(rows, node):
        # rows holds the values already collected for nodes to
        # the right of `node`, one row per level, starting at
        # node's own level.
        # An empty accumulator is seeded with exactly one empty
        # row. The former ([], []) seed unpacked to an extra []
        # that surfaced as a trailing empty row whenever a leaf
        # opened a new level.
        here, *below = rows or [[]]
        # Children are folded right to left so that each one is
        # prepended in front of its right-hand siblings.
        return [[root(node)] + here] + reduce(
            go, nest(node)[::-1], below
        )
    return go([], tree)
|
|
|
|
|
|
# preorder :: a -> [[a]] -> [a]
def preorder(x):
    '''Fold step for a node-first traversal: this node's
    value, then the already-flattened traversals of its
    children, left to right.'''
    def step(childLists):
        return [x] + concat(childLists)
    return step
|
|
|
|
|
|
# inorder :: a -> [[a]] -> [a]
def inorder(x):
    '''Fold step for an in-order traversal: the traversal
    of the first child (if any), then this node's value,
    then the traversals of the remaining children.'''
    def step(childLists):
        if not childLists:
            return [x]
        firstChild, *otherChildren = childLists
        return firstChild + [x] + concat(otherChildren)
    return step
|
|
|
|
|
|
# postorder :: a -> [[a]] -> [a]
def postorder(x):
    '''Fold step for a children-first traversal: all the
    descendants' traversals, then this node's value.'''
    def step(childLists):
        descendants = concat(childLists)
        return descendants + [x]
    return step
|
|
|
|
|
|
# levelorder :: Tree a -> [a]
def levelorder(tree):
    '''Breadth-first listing: the rows produced by levels,
    concatenated from the top row downwards.

    Unlike the other traversals, this takes a whole tree
    rather than acting as a fold step.'''
    rows = levels(tree)
    return concat(rows)
|
|
|
|
|
|
# treeSum :: Int -> [Int] -> Int
def treeSum(x):
    '''Fold step: this node's value plus the totals already
    computed for each of its children.'''
    def step(childTotals):
        return x + sum(childTotals)
    return step
|
|
|
|
|
|
# treeProduct :: Int -> [Int] -> Int
def treeProduct(x):
    '''Fold step: this node's value multiplied by the
    products already computed for its children.'''
    def step(childProducts):
        return x * numericProduct(childProducts)
    return step
|
|
|
|
|
|
# treeMax :: Ord a => a -> [a] -> a
def treeMax(x):
    '''Fold step: the largest of this node's value and the
    maxima already found in its child subtrees.'''
    def step(childMaxima):
        candidates = [x] + childMaxima
        return max(candidates)
    return step
|
|
|
|
|
|
# treeMin :: Ord a => a -> [a] -> a
def treeMin(x):
    '''Fold step: the smallest of this node's value and the
    minima already found in its child subtrees.'''
    def step(childMinima):
        candidates = [x] + childMinima
        return min(candidates)
    return step
|
|
|
|
|
|
# nodeCount :: Int -> [Int] -> Int
def nodeCount(_):
    '''Fold step: the node counts of all child subtrees,
    plus one for this node. The node's value is ignored.'''
    def step(childCounts):
        return sum(childCounts) + 1
    return step
|
|
|
|
|
|
# treeWidth :: Int -> [Int] -> Int
def treeWidth(_):
    '''Fold step: a leaf has width 1; any other node is as
    wide as its children's widths combined. The node's
    value is ignored.'''
    def step(childWidths):
        if not childWidths:
            return 1
        return sum(childWidths)
    return step
|
|
|
|
|
|
# treeDepth :: Int -> [Int] -> Int
def treeDepth(_):
    '''Fold step: one level deeper than the deepest child
    subtree (a leaf has depth 1). The node's value is
    ignored.'''
    def step(childDepths):
        return 1 + max(childDepths, default=0)
    return step
|
|
|
|
|
|
# ------------------------- TEST -------------------------
|
|
# main :: IO ()
def main():
    '''Tree traversals - accumulating and folding'''
    # NOTE: the docstring above is printed as the table heading
    # (via main.__doc__), so it is part of the program's output.

    # tree :: Tree Int
    tree = Node(1)([
        Node(2)([
            Node(4)([
                Node(7)([])
            ]),
            Node(5)([])
        ]),
        Node(3)([
            Node(6)([
                Node(8)([]),
                Node(9)([])
            ])
        ])
    ])

    # Every entry except levelorder is a curried fold step and
    # must be run through foldTree; levelorder already takes a
    # whole tree. Test by identity rather than by comparing
    # __name__ strings, so renaming levelorder cannot silently
    # route it through foldTree.
    def apply(f):
        return (f if f is levelorder else foldTree(f))(tree)

    print(
        fTable(main.__doc__ + ':\n')(fName)(str)(apply)([
            preorder, inorder, postorder, levelorder,
            treeSum, treeProduct, treeMin, treeMax,
            nodeCount, treeWidth, treeDepth
        ])
    )
|
|
|
|
|
|
# ----------------------- GENERIC ------------------------
|
|
|
|
# Node :: a -> [Tree a] -> Tree a
def Node(v):
    '''Constructor for a Tree node, in curried form:
    Node(v)(xs) pairs the value v with a list of zero or
    more child trees, returning a dict with the keys
    'type', 'root' and 'nest'.'''
    def withChildren(xs):
        node = {'type': 'Node', 'root': v, 'nest': xs}
        return node
    return withChildren
|
|
|
|
|
|
# nest :: Tree a -> [Tree a]
def nest(tree):
    '''Accessor for the child trees of a node; None if the
    node dict has no 'nest' key.'''
    return tree.get('nest')
|
|
|
|
|
|
# root :: Tree a -> a
def root(tree):
    '''Accessor for the value held at a node; None if the
    node dict has no 'root' key.'''
    return tree.get('root')
|
|
|
|
|
|
# concat :: [[a]] -> [a]
# concat :: [String] -> String
def concat(xxs):
    '''The concatenation of all the elements in a list.

    A list of strings is joined into a single string. Any
    other list of sequences is flattened into one list, and
    an empty input gives [].

    The choice is made from the outer elements themselves,
    not from the flattened items. A list of lists of strings,
    such as [['a'], ['b']], therefore flattens to ['a', 'b']
    rather than being joined into 'ab'. The old behaviour made
    `[x] + concat(...)` raise TypeError in the tree
    traversals whenever node values were strings.
    '''
    xs = list(xxs)
    if xs and isinstance(xs[0], str):
        return ''.join(xs)
    return list(chain.from_iterable(xs))
|
|
|
|
|
|
# numericProduct :: [Num] -> Num
def numericProduct(xs):
    '''The arithmetic product of all numbers in xs, taken
    left to right from 1; an empty xs gives 1.'''
    product = 1
    for n in xs:
        product = product * n
    return product
|
|
|
|
|
|
# ---------------------- FORMATTING ----------------------
|
|
|
|
# fName :: (a -> b) -> String
def fName(f):
    '''The __name__ attribute of the given function.'''
    return getattr(f, '__name__')
|
|
|
|
|
|
# fTable :: String -> (a -> String) ->
#                     (b -> String) ->
#        (a -> b) -> [a] -> String
def fTable(s):
    '''Heading -> x display function ->
    fx display function -> f -> xs -> tabular string.

    Produces one row per x: xShow(x) right-aligned to the
    widest label, then ' -> ', then fxShow(f(x)). The result
    is the heading, a newline, and the rows separated by
    newlines.

    xs may be any iterable; it is materialised once. An
    empty xs yields just the heading followed by a newline.
    '''
    def go(xShow, fxShow, f, xs):
        # xs is traversed twice (labels, then rows). A generator
        # would be exhausted after the first pass and silently
        # produce an empty table, so take a list copy first.
        items = list(xs)
        labels = [xShow(x) for x in items]
        # default=0 stops an empty table from raising ValueError.
        w = max(map(len, labels), default=0)
        return s + '\n' + '\n'.join(
            label.rjust(w, ' ') + ' -> ' + fxShow(f(x))
            for x, label in zip(items, labels)
        )
    return lambda xShow: lambda fxShow: (
        lambda f: lambda xs: go(
            xShow, fxShow, f, xs
        )
    )
|
|
|
|
|
|
# Print the demonstration table only when this file is run
# as a script, not when it is imported as a module.
if __name__ == '__main__':
    main()
|