# Source code for afe.driver.compile_step

#########################################################
# Copyright (C) 2022 SiMa Technologies, Inc.
#
# This material is SiMa proprietary and confidential.
#
# This material may not be copied or distributed without
# the express prior written permission of SiMa.
#
# All rights reserved.
#########################################################
# Code owner: Christopher Rodrigues
#########################################################
"""
Construction of compilation pipelines from compilation steps.
"""
from dataclasses import dataclass
from typing import TypeVar, Generic, Callable, Tuple, Optional, Any

# Type of the value a continuation (_Cont) accepts. It is contravariant because
# it appears only in input position.
_KParam = TypeVar("_KParam", contravariant=True)
# Result type of the new step produced by CompileStep.then/map/from_thunk.
_T = TypeVar("_T", covariant=True)
# Result type of a CompileStep, and the final result type of a _Cont.
_R = TypeVar("_R", covariant=True)
# Invariant type variable for intermediate/pass-through values.
_A = TypeVar("_A")


class CompileStep(Generic[_R]):
    """
    A step of processing a model in AFE.

    A CompileStep is a deferred computation whose result has type ``_R``.
    Steps are assembled with :meth:`pure`, :meth:`from_thunk`, :meth:`then`
    and :meth:`map`; no user code is executed until :meth:`run` is called.
    """
    # Sum type: every concrete step is either a _Pure or a _Bind.

    def run(self) -> _R:
        """
        Execute this step, including everything chained onto it, and return
        the final result.
        """
        finished = _DoneK()
        return _run_compile_step(self, finished)

    @staticmethod
    def pure(value: _A) -> "CompileStep[_A]":
        """
        Build a step whose result is ``value``; running it executes nothing.
        """
        return _Pure(value=value)

    def then(self, continuation: Callable[[_R], "CompileStep[_T]"]) -> "CompileStep[_T]":
        """
        Sequence a follow-up step after this one.

        :param continuation: Called with this step's result when the combined
            step runs; it must return the CompileStep to run next.
        :return: A step whose result is the result of the step that
            ``continuation`` returns.
        """
        return _Bind(action=self, continuation=continuation)

    # The helpers below are built solely from ``pure`` and ``then``.

    @staticmethod
    def from_thunk(thunk: Callable[[], _T]) -> "CompileStep[_T]":
        """
        Build a step that calls ``thunk`` with no arguments when it runs and
        yields the thunk's return value.
        """
        def _force(_unit: object) -> "CompileStep[_T]":
            # Calling thunk inside a continuation defers it until run time.
            return CompileStep.pure(thunk())

        return CompileStep.pure(()).then(_force)

    def map(self, f: Callable[[_R], _T]) -> "CompileStep[_T]":
        """
        Build a step that applies the plain function ``f`` to this step's
        result.
        """
        def _lift(result: _R) -> "CompileStep[_T]":
            return CompileStep.pure(f(result))

        return self.then(_lift)
@dataclass
class _Pure(CompileStep[_R]):
    """
    Produce the value.

    Leaf of a step tree: running it does no work and yields ``value``.
    """
    # Result this step evaluates to.
    value: _R


@dataclass
class _Bind(CompileStep[_R]):
    """
    Perform an action, then apply the continuation to its result.
    """
    # Step to run first.
    action: CompileStep[_A]
    # Called with ``action``'s result; returns the step to run next.
    continuation: Callable[[_A], CompileStep[_R]]

    def __post_init__(self):
        # Internal invariant: CompileStep.then always passes a CompileStep as
        # the action. A non-step here would otherwise surface only later,
        # inside run().
        assert isinstance(self.action, CompileStep), \
            f"_Bind.action must be a CompileStep, got {type(self.action).__name__}"


class _Cont(Generic[_KParam, _R]):
    """
    A continuation used in the implementation of CompileStep.run.

    It accepts a value of type ``_KParam`` and eventually yields a final
    result of type ``_R``.
    """
    # This is a sum type: instances are _DoneK or _BindK.
    pass


@dataclass
class _DoneK(_Cont[_A, _A]):
    """
    Nothing else to do: the value received is the final result.
    """
    pass


@dataclass
class _BindK(_Cont[_KParam, _R]):
    """
    Resume a suspended binding.

    Instances form a linked stack; ``continuation`` is the next frame down.
    """
    # The suspended _Bind's continuation; maps a value to the next step.
    cont1: Callable[[_KParam], CompileStep[_A]]
    # Receives the result of the step that ``cont1`` returns.
    continuation: _Cont[_A, _R]


def _enter_compile_step(step: CompileStep[Any], continuation: _Cont[Any, _R]) \
        -> Tuple[Any, _Cont[Any, _R]]:
    """
    Traverse a CompileStep to find the next runnable continuation.

    Walks down the ``action`` chain of nested ``_Bind`` nodes. Each bind's
    continuation is pushed onto ``continuation`` (so the innermost bind ends
    up on top) until a ``_Pure`` leaf is reached.

    :param step: The step to traverse.
    :param continuation: The continuation stack to push suspended binds onto.
    :return: The ``_Pure`` leaf's value and the extended continuation stack.
        The value's type is that of the leaf, not of ``step`` as a whole.
    :raises TypeError: If a node is neither a ``_Bind`` nor a ``_Pure``.
    """
    while True:
        if isinstance(step, _Bind):
            continuation = _BindK(step.continuation, continuation)
            step = step.action
        elif isinstance(step, _Pure):
            return step.value, continuation
        else:
            # Most commonly reached when a `then` continuation returns a
            # plain value instead of a CompileStep.
            raise TypeError(
                f"Invalid CompileStep: {type(step).__name__!r} object is not a "
                "_Pure or _Bind step (a `then` continuation must return a "
                "CompileStep; use `map` for plain functions)"
            )


def _run_compile_step(step: CompileStep[_A], continuation: _Cont[_A, _R]) -> _R:
    """
    Run a CompileStep.

    This is the implementation of ``step.run()``, which calls it with
    ``_DoneK()`` as the initial continuation.

    :param step: The step to run
    :param continuation: What to do with the step's output; ``_DoneK()``
        returns it unchanged.
    :return: The step's output after it has been passed through
        ``continuation``.
    :raises TypeError: If a step or continuation of an unrecognised kind is
        encountered, e.g. a `then` continuation that returned a non-step.
    """
    # This function is CPS-transformed to avoid excessive stack frame growth:
    # pending binds live in the explicit _BindK chain ``k`` rather than on the
    # Python call stack, so this loop never recurses.
    k: _Cont[Any, _R] = continuation
    while True:
        value, k = _enter_compile_step(step, k)
        if isinstance(k, _BindK):
            step = k.cont1(value)
            k = k.continuation
        elif isinstance(k, _DoneK):
            return value
        else:
            raise TypeError(f"Invalid _Cont: {type(k).__name__!r}")