#########################################################
# Copyright (C) 2022 SiMa Technologies, Inc.
#
# This material is SiMa proprietary and confidential.
#
# This material may not be copied or distributed without
# the express prior written permission of SiMa.
#
# All rights reserved.
#########################################################
# Code owner: Christopher Rodrigues
#########################################################
"""
Construction of compilation pipelines from compilation steps.
"""
from dataclasses import dataclass
from typing import TypeVar, Generic, Callable, Tuple, Optional, Any
# Type of the value a continuation accepts. It only appears as an input
# (the first parameter of _Cont and the argument of _BindK.cont1), hence
# contravariant.
_KParam = TypeVar("_KParam", contravariant=True)
# Result type of a step produced by a combinator (then, from_thunk, map).
_T = TypeVar("_T", covariant=True)
# Result type of a CompileStep, and the final result type of a _Cont.
_R = TypeVar("_R", covariant=True)
# General-purpose invariant type variable, used for intermediate values.
_A = TypeVar("_A")
class CompileStep(Generic[_R]):
    """
    A step of processing a model in AFE.

    A CompileStep describes a computation that produces a value of type _R.
    Building a step does not call any of the callables given to it; they are
    invoked only when run() is called.
    """
    # This is a sum type.  Its variants are the subclasses _Pure and _Bind.

    def run(self) -> _R:
        """
        Execute this step and return the value it produces.

        :return: The result of this step.
        :raises TypeError: If this step, or a step produced by a continuation,
            is not one of the known CompileStep variants.
        """
        # Start with the "nothing left to do" continuation.
        return _run_compile_step(self, _DoneK())

    @staticmethod
    def pure(value: _A) -> "CompileStep[_A]":
        """
        Make a CompileStep that evaluates to the given value and does not
        execute anything.

        :param value: The value that the new step produces.
        :return: A step that yields value when run.
        """
        return _Pure(value)

    def then(self, continuation: Callable[[_R], "CompileStep[_T]"]) -> "CompileStep[_T]":
        """
        Compose this step with a continuation that processes this step's
        result.

        :param continuation: Function that receives this step's result and
            returns the step to execute next.
        :return: A step that runs this step, then the step returned by
            continuation.
        """
        return _Bind(self, continuation)

    # The following functions are implemented using the functions defined above.

    @staticmethod
    def from_thunk(thunk: Callable[[], _T]) -> "CompileStep[_T]":
        """
        Make a CompileStep that evaluates the given thunk when it runs.

        :param thunk: Zero-argument function to call when the step runs.
        :return: A step that yields the thunk's return value.
        """
        # Bind on a placeholder value so that the call to thunk happens
        # inside a continuation, i.e. when the step runs and not now.
        unit_step = CompileStep.pure(())
        return unit_step.then(lambda _unit: CompileStep.pure(thunk()))

    def map(self, f: Callable[[_R], _T]) -> "CompileStep[_T]":
        """
        Map a function over the result of this CompileStep.

        :param f: Function to apply to this step's result.
        :return: A step that yields f applied to this step's result.
        """
        return self.then(lambda result: CompileStep.pure(f(result)))
@dataclass
class _Pure(CompileStep[_R]):
    """
    CompileStep variant that yields an already-known value.
    """
    # The value this step evaluates to.
    value: _R
@dataclass
class _Bind(CompileStep[_R]):
    """
    CompileStep variant that sequences two computations: run an action,
    then hand its result to a continuation that yields the next step.
    """
    # The step to execute first.
    action: CompileStep[_A]
    # Receives the result of action and returns the step to execute next.
    continuation: Callable[[_A], CompileStep[_R]]

    def __post_init__(self):
        # Sanity check on construction: the first component must be a step.
        assert isinstance(self.action, CompileStep)
class _Cont(Generic[_KParam, _R]):
    """
    Base class of the continuations used to implement CompileStep.run.

    A continuation accepts a value of type _KParam and leads to a final
    result of type _R.
    """
    # This is a sum type.  Its variants are the subclasses _DoneK and _BindK.
@dataclass
class _DoneK(_Cont[_A, _A]):
    """
    Terminal continuation: no work remains, so the value it receives is the
    final result.
    """
@dataclass
class _BindK(_Cont[_KParam, _R]):
    """
    Continuation that resumes a binding which was suspended while its action
    was being evaluated.
    """
    # Called with the action's result to obtain the next step to run.
    cont1: Callable[[_KParam], CompileStep[_A]]
    # What to do with the result of the step returned by cont1.
    continuation: _Cont[_A, _R]
def _enter_compile_step(
    step: CompileStep[_A], continuation: _Cont[_A, _R]
) -> Tuple[_A, _Cont[_A, _R]]:
    """
    Descend through nested bindings until a plain value is reached.

    Each _Bind encountered on the way has its continuation pushed onto the
    pending continuation, and the walk proceeds into the binding's action.

    :param step: The step to descend into.
    :param continuation: The continuation pending when the walk starts.
    :return: The value of the innermost _Pure step, and the continuation to
        resume with that value.
    :raises TypeError: If a step is neither a _Bind nor a _Pure.
    """
    current = step
    pending = continuation

    # Unwind the chain of bindings iteratively rather than recursively.
    while isinstance(current, _Bind):
        pending = _BindK(current.continuation, pending)
        current = current.action

    if isinstance(current, _Pure):
        return current.value, pending

    raise TypeError("Invalid CompileStep")
def _run_compile_step(step: CompileStep[_A], continuation: _Cont[_A, _R]) -> _R:
    """
    Run a CompileStep to completion.  CompileStep.run delegates to this
    function.

    :param step: The step to run
    :param continuation: The continuation to apply to the step's result
    :return: The value delivered to the final _DoneK continuation
    :raises TypeError: If an unknown kind of step or continuation is found
    """
    # Written as a loop over explicit continuations (CPS-transformed) so that
    # long chains of steps do not grow the Python call stack.
    pending: _Cont[Any, _R] = continuation
    current = step
    while True:
        result, pending = _enter_compile_step(current, pending)

        if isinstance(pending, _BindK):
            # Resume the suspended binding: compute the next step from the
            # result, then continue with the rest of the continuation chain.
            current = pending.cont1(result)
            pending = pending.continuation
            continue

        if isinstance(pending, _DoneK):
            return result

        raise TypeError("Invalid _Cont")