# Source: RosettaCodeData/Task/Parametric-polymorphism/Python/parametric-polymorphism-2.py
#
# 34 lines
# 910 B
# Python

"""Parametric polymorphism. Requires Python >= 3.12."""
from typing import Callable
from typing import Iterable
class Tree[T]:
def __init__(self, value: T):
self.value = value
self.left: Tree[T] | None = None
self.right: Tree[T] | None = None
def map(self, func: Callable[[T], T]) -> Iterable[T]:
yield func(self.value)
if self.left is not None:
yield from self.left.map(func)
if self.right is not None:
yield from self.right.map(func)
if __name__ == "__main__":
    # Example tree:
    #            7
    #          /   \
    #        42    101
    #              /
    #             1
    root = Tree(7)
    root.left = Tree(42)
    root.right = Tree(101)
    root.right.left = Tree(1)

    # The two assignments below are rejected by a static type checker, so
    # they stay commented out to keep the script runnable.
    # Explicit Tree[int] with a str argument: "foo" is not an int.
    # root.left.left = Tree[int]("foo")
    # Inferred Tree[str] where a Tree[int] child is expected.
    # root.left.left = Tree("bar")

    incremented = list(root.map(lambda value: value + 1))
    print(incremented)  # [8, 43, 102, 2]