233 lines
5.6 KiB
Python
233 lines
5.6 KiB
Python
'''Tree traversals'''
|
|
|
|
from itertools import (chain, takewhile)
|
|
from functools import (reduce)
|
|
from operator import (mul)
|
|
|
|
|
|
# foldTree :: (a -> [b] -> b) -> Tree a -> b
def foldTree(f):
    '''The catamorphism on trees.

       f is a curried step function: applied to a node's value
       it returns a function over the list of results already
       computed for that node's children. The function returned
       here reduces a whole tree to one summary value, depth-first.
    '''
    def fold(node):
        # Resolve the combiner for this node's value first,
        # then hand it the folded result of each child in order.
        combine = f(root(node))
        return combine([fold(child) for child in nest(node)])

    return fold
|
|
|
|
|
|
# levels :: Tree a -> [[a]]
def levels(tree):
    '''The root values of the tree's nodes grouped by depth:
       one list per level, with the top level first.

       Every row is a concrete list rather than a one-shot map
       iterator, so the result can be indexed, printed and
       traversed more than once.
    '''
    # Successive generations of nodes: [tree], its children,
    # their children, ... stopping at the first empty generation.
    generations = takewhile(
        bool,
        iterate(concatMap(nest))([tree])
    )
    return [
        [root(node) for node in generation]
        for generation in generations
    ]
|
|
|
|
|
|
# preorder :: a -> [[a]] -> [a]
def preorder(x):
    '''A fold step for pre-order traversal: the value of
       this node, followed by the flattened traversals
       of its children.
    '''
    def visit(xs):
        return [x] + concat(xs)

    return visit
|
|
|
|
|
|
# inorder :: a -> [[a]] -> [a]
def inorder(x):
    '''A fold step for in-order traversal: the traversal of
       any first child, then the value of this node, then
       the flattened traversals of the remaining children.
    '''
    def visit(xs):
        # A node without children contributes only itself.
        if not xs:
            return [x]
        first, rest = xs[0], xs[1:]
        return first + [x] + concat(rest)

    return visit
|
|
|
|
|
|
# postorder :: a -> [[a]] -> [a]
def postorder(x):
    '''A fold step for post-order traversal: the flattened
       traversals of the children, then the value of this node.
    '''
    def visit(xs):
        return concat(xs) + [x]

    return visit
|
|
|
|
|
|
# levelorder :: Tree a -> [a]
def levelorder(tree):
    '''Breadth-first traversal: the rows produced by levels,
       concatenated from the top of the tree downwards.
    '''
    rows = levels(tree)
    return concat(rows)
|
|
|
|
|
|
# treeSum :: Tree Int -> Int
def treeSum(x):
    '''A fold step: the value of this node plus the sums
       already computed for its children.
    '''
    def total(xs):
        return x + sum(xs)

    return total
|
|
|
|
|
|
# treeProduct :: Tree Int -> Int
def treeProduct(x):
    '''A fold step: the value of this node multiplied by the
       products already computed for its children.
    '''
    def product(xs):
        # An empty list of child products contributes 1.
        return x * reduce(mul, xs, 1)

    return product
|
|
|
|
|
|
# treeMax :: Tree Int -> Int
def treeMax(x):
    '''A fold step: the largest of this node's value and
       the maxima already computed for its children.
    '''
    def largest(xs):
        candidates = [x] + xs
        return max(candidates)

    return largest
|
|
|
|
|
|
# treeMin :: Tree Int -> Int
def treeMin(x):
    '''A fold step: the smallest of this node's value and
       the minima already computed for its children.
    '''
    def smallest(xs):
        candidates = [x] + xs
        return min(candidates)

    return smallest
|
|
|
|
|
|
# nodeCount :: Tree a -> Int
def nodeCount(_):
    '''A fold step which ignores the node's value: one (for
       this node) plus the counts already computed for
       its children.
    '''
    def count(xs):
        return 1 + sum(xs)

    return count
|
|
|
|
|
|
# treeWidth :: Tree a -> Int
def treeWidth(_):
    '''A fold step which ignores the node's value: the sum of
       the widths of any children, or 1 for a childless node.
    '''
    def width(xs):
        if xs:
            return sum(xs)
        # A leaf occupies a single column.
        return 1

    return width
|
|
|
|
|
|
# treeDepth :: Tree a -> Int
def treeDepth(_):
    '''A fold step which ignores the node's value: one more
       than the depth of the deepest child, or 1 for a
       childless node.
    '''
    def depth(xs):
        if not xs:
            return 1
        return 1 + max(xs)

    return depth
|
|
|
|
|
|
# TEST ----------------------------------------------------
|
|
# main :: IO ()
def main():
    '''Tree traversals - accumulating and folding'''

    # The docstring above doubles as the heading of the printed
    # table, so its text is part of the program's output.

    # tree :: Tree Int
    #
    #          1
    #        /   \
    #       2     3
    #      / \    |
    #     4   5   6
    #     |      / \
    #     7     8   9
    left = Node(2)([
        Node(4)([
            Node(7)([])
        ]),
        Node(5)([])
    ])
    right = Node(3)([
        Node(6)([
            Node(8)([]),
            Node(9)([])
        ])
    ])
    tree = Node(1)([left, right])

    def applied(f):
        # levelorder consumes the tree directly; every other
        # function is a fold step which foldTree must lift.
        if 'levelorder' == fName(f):
            return f(tree)
        return foldTree(f)(tree)

    functions = [
        preorder, inorder, postorder, levelorder,
        treeSum, treeProduct, treeMin, treeMax,
        nodeCount, treeWidth, treeDepth
    ]
    heading = main.__doc__ + ':\n'
    print(
        fTable(heading)(fName)(str)(applied)(functions)
    )
|
|
|
|
|
|
# GENERIC -------------------------------------------------
|
|
|
|
# Node :: a -> [Tree a] -> Tree a
def Node(v):
    '''Constructor for a Tree node, which pairs a value of
       some kind with a list of zero or more child trees.

       The node is a plain dict with the keys 'type', 'root'
       and 'nest'.
    '''
    def withChildren(xs):
        return {'type': 'Node', 'root': v, 'nest': xs}

    return withChildren
|
|
|
|
|
|
# nest :: Tree a -> [Tree a]
def nest(tree):
    '''Accessor function for the children of a tree node.

       Returns the value stored under the 'nest' key, or an
       empty list when that key is absent, so that a node
       without recorded children is treated as a leaf and the
       result can always be iterated.
    '''
    # An empty list (rather than None) keeps callers which map
    # or chain over the children from failing on such a node.
    return tree['nest'] if 'nest' in tree else []
|
|
|
|
|
|
# root :: Dict -> a
def root(tree):
    '''Accessor function for the data of a tree node:
       the value stored under the 'root' key, or None
       when the node has no such key.
    '''
    if 'root' in tree:
        return tree['root']
    return None
|
|
|
|
|
|
# concat :: [[a]] -> [a]
# concat :: [String] -> String
def concat(xxs):
    '''The concatenation of all the elements in a list.

       A non-empty sequence consisting only of strings is
       joined into a single string. Anything else is flattened
       by one level into a list. An empty input gives an
       empty list.
    '''
    parts = list(xxs)
    # Decide on the shape of the input, not on the flattened
    # items: a list of lists of strings must stay a list.
    if parts and all(isinstance(part, str) for part in parts):
        return ''.join(parts)
    return list(chain.from_iterable(parts))
|
|
|
|
|
|
# concatMap :: (a -> [b]) -> [a] -> [b]
def concatMap(f):
    '''A concatenated list over which a function has been mapped.

       f must return an iterable for each item; the items of
       those iterables are collected, in order, into one list.
       (An f which wraps its output in a list, using an empty
       list to represent failure, gives the list monad.)
    '''
    def go(xs):
        return [y for x in xs for y in f(x)]

    return go
|
|
|
|
|
|
# curry :: ((a, b) -> c) -> a -> b -> c
def curry(f):
    '''A curried version of a two-argument function:
       it takes its arguments one at a time.
    '''
    def first(a):
        def second(b):
            return f(a, b)
        return second

    return first
|
|
|
|
|
|
# iterate :: (a -> a) -> a -> Gen [a]
def iterate(f):
    '''An endless generator of repeated applications of f:
       x, f(x), f(f(x)), ...
    '''
    def unfold(x):
        current = x
        yield current
        # f is applied only when the next value is requested.
        while True:
            current = f(current)
            yield current

    return unfold
|
|
|
|
|
|
# numericProduct :: [Num] -> Num
def numericProduct(xs):
    '''The arithmetic product of all the numbers in xs
       (1 for an empty xs).
    '''
    product = 1
    for x in xs:
        product = product * x
    return product
|
|
|
|
|
|
# FORMATTING ----------------------------------------------
|
|
|
|
# fName :: (a -> b) -> String
def fName(f):
    '''The name under which the function f was defined
       (its __name__ attribute).
    '''
    return getattr(f, '__name__')
|
|
|
|
|
|
# fTable :: String -> (a -> String) ->
#           (b -> String) -> (a -> b) -> [a] -> String
def fTable(s):
    '''Heading -> x display function -> fx display function ->
       f -> xs -> tabular string.

       The result is the heading s followed by one row per x
       in xs, of the form 'xShow(x) -> fxShow(f(x))', with the
       left-hand column right-justified to a common width.
       An empty xs gives the heading and a newline only.
    '''
    def go(xShow, fxShow, f, xs):
        # xs is traversed twice below, so fix it as a list
        # in case the caller supplied a one-shot iterator.
        xs = list(xs)
        ys = [xShow(x) for x in xs]
        # default=0 stops an empty xs from raising ValueError.
        w = max(map(len, ys), default=0)
        return s + '\n' + '\n'.join(map(
            lambda x, y: y.rjust(w, ' ') + ' -> ' + fxShow(f(x)),
            xs, ys
        ))

    return lambda xShow: lambda fxShow: lambda f: lambda xs: go(
        xShow, fxShow, f, xs
    )
|
|
|
|
|
|
# Run the demonstration only when this file is executed
# as a script, not when it is imported as a module.
if __name__ == '__main__':
    main()
|