Source code for afe.ir.execute

#########################################################
# Copyright (C) 2020 SiMa Technologies, Inc.
#
# This material is SiMa proprietary and confidential.
#
# This material may not be copied or distributed without
# the express prior written permission of SiMa.
#
# All rights reserved.
#########################################################
# Code owner: Joey Chou
#########################################################
import numpy as np
from typing import List, Dict, Tuple, Protocol, Callable

from afe.ir.attributes import ExternalAttrs, PlaceholderAttrs, get_data_value_quant_result_scale_with_dummy, \
    PlaceholderQuantAttrs
from afe.ir.defines import Status, InputName, NodeName, DataValue, Quantization, TupleValue, \
    get_expected_tensor_value, get_expected_tuple_values
import afe.ir.operations as operations
from afe.ir.node import AwesomeNode, node_is_sima_ir
from afe.ir.quantization_utils import dequantize_input_dict, quantize_input_dict, quantization_data_value_to_output_list
from afe.ir.sima_ir import SiMaIR, SiMaIRTensorTypes
from afe.ir.tensor_type import TensorType, ScalarType
from afe.core.configs import RunConfigs


def _check_tensor_type(t: TensorType, a: np.ndarray) -> None:
    """
    Verify that t is the type of a.  Raise an exception otherwise.

    :param t: The declared (expected) tensor type.
    :param a: The array actually produced, whose dtype and shape are compared against t.
    :raises AssertionError: If a's scalar type or shape differs from t.
    """
    actual_scalar = ScalarType.from_numpy(a.dtype)
    # t is what the node declared; a is what it produced, so t is the "expected" side.
    assert t.scalar == actual_scalar, f'Expected {t.scalar}, got {actual_scalar}'
    assert t.shape == a.shape, f'Expected {t.shape}, got {a.shape}'


def _check_output_type(ir: SiMaIR, output: SiMaIRTensorTypes) -> None:
    """
    Verify that a node's output is consistent with a node's type.
    Raise an exception if it doesn't match.

    Only a single tensor or a flat (non-nested) tuple/list of tensors is handled.

    :param ir: A model graph node
    :param output: Result of executing the node in ir
    :raises AssertionError: If the output's scalar type, shape, or number of
        tuple elements differs from the node's declared output type.
    :raises TypeError: If output is neither an np.ndarray nor a tuple/list.
    """
    expected_type = ir.get_type().output

    if isinstance(output, np.ndarray):
        _check_tensor_type(get_expected_tensor_value(expected_type), output)
    elif isinstance(output, (tuple, list)):
        # Use the builtin types: isinstance checks against typing.Tuple/List are deprecated aliases.
        expected_items = get_expected_tuple_values(expected_type)
        assert len(expected_items) == len(output), \
            f'Expected {len(expected_items)} outputs, got {len(output)}'
        for item_type, item in zip(expected_items, output):
            _check_tensor_type(item_type, item)
    else:
        raise TypeError(f"Invalid DataValue instance: expected np.ndarray, tuple or list, "
                        f"got {type(output).__name__}")


def _check_input_shape(ir: SiMaIR, inputs: Dict[InputName, SiMaIRTensorTypes]) -> None:
    """
    Verify that the inputs have the shapes declared by the node's type.
    Raise an exception otherwise.

    Placeholder nodes are skipped because they have no inputs.

    :param ir: A model graph node
    :param inputs: Inputs of the node, expected in the same key order as the
        node's declared input types.
    :raises AssertionError: If the input names or their order, a tuple input's
        length, or any input shape differs from the node's declared input types.
    """
    if isinstance(ir.attrs, PlaceholderAttrs) or isinstance(ir.quant_attrs, PlaceholderQuantAttrs):
        return  # PlaceholderOp does not have inputs

    expected_inputs = ir.get_type().inputs

    # Compare as lists so that the order of keys is checked, not only the key set.
    assert list(expected_inputs.keys()) == list(inputs.keys()), \
        f"Expected inputs {list(expected_inputs.keys())}, got {list(inputs.keys())}"

    for (name, expected), actual in zip(expected_inputs.items(), inputs.values()):
        if isinstance(expected, TupleValue):
            assert len(expected.elements) == len(actual), \
                f"Input {name}: expected {len(expected.elements)} tuple elements, got {len(actual)}"
            for element_type, element in zip(expected.elements, actual):
                assert element_type.value.shape == element.shape, \
                    f"Input {name}: expected {element_type.value.shape}, got {element.shape}"
        else:
            expected_shape = expected.value.shape
            # np.asarray normalizes array-like inputs so their shape can be read uniformly.
            actual_shape = np.asarray(actual).shape
            assert expected_shape == actual_shape, f"Expected {expected_shape}, got {actual_shape}"


#############################
# Functions to execute SimaIR
#############################
def execute_ir(ir: SiMaIR, inputs: Dict[InputName, SiMaIRTensorTypes], config: RunConfigs) -> SiMaIRTensorTypes:
    """
    Run the non-quantized implementation of a node's SiMaIR.

    Before running, the inputs are checked against the node's declared input
    shapes. The result is then checked against the node's declared output type.

    :param ir: SiMaIR. The IR of the node to be executed.
    :param inputs: Dict[InputName, SiMaIRTensorTypes]. Inputs of the node.
    :param config: Configuration for how to execute a node.
    :return: SiMaIRTensorTypes. Results obtained by running the SiMaIR.
    """
    _check_input_shape(ir, inputs)
    result = ir.run(inputs, config)
    _check_output_type(ir, result)
    return result
def execute_ir_quant(ir: SiMaIR, inputs: Dict[InputName, SiMaIRTensorTypes], config: RunConfigs) -> SiMaIRTensorTypes:
    """
    Run the quantized implementation of a node's SiMaIR.

    Before running, the inputs are checked against the node's declared input
    shapes. The result is then checked against the node's declared output type.

    :param ir: SiMaIR. The IR of the node to be executed.
    :param inputs: Dict[InputName, SiMaIRTensorTypes]. Inputs of the node.
    :param config: Configuration for how to execute a node.
    :return: SiMaIRTensorTypes. Results obtained by running the SiMaIR.
    """
    _check_input_shape(ir, inputs)
    result = ir.run_quant(inputs, config)
    _check_output_type(ir, result)
    return result
##################################
# Functions to execute AwesomeNode
##################################
class NodeExecutor(Protocol):
    """
    A callable object that executes a node as part of AwesomeNet.run.

    Both non-quantized executors (from node_executor) and quantized executors
    (from node_quant_executor) satisfy this protocol. The RunConfigs an
    executor uses is fixed when the executor is created.
    """

    def __call__(self, node: AwesomeNode, inputs: Dict[InputName, SiMaIRTensorTypes],
                 node_outputs: Dict[NodeName, SiMaIRTensorTypes]) -> None:
        """
        Execute an AwesomeNode. The node's output is inserted into node_outputs.

        :param node: AwesomeNode. The node to be executed.
        :param inputs: Dict[InputName, SiMaIRTensorTypes]. The inputs to the node.
        :param node_outputs: Dict[NodeName, SiMaIRTensorTypes]. Dictionary holding outputs of the nodes.
        """
        ...
def node_executor(config: RunConfigs) -> NodeExecutor:
    """
    Build an executor that runs non-quantized AwesomeNodes with a fixed configuration.

    :param config: Configuration for how to execute a node.
    :return: Executor. Calling it runs node.ir through execute_ir and stores
        the result in node_outputs[node.name].
    """
    def _run_float_node(node: AwesomeNode, inputs: Dict[InputName, SiMaIRTensorTypes],
                        node_outputs: Dict[NodeName, SiMaIRTensorTypes]):
        node_outputs[node.name] = execute_ir(node.ir, inputs, config)

    return _run_float_node
def node_quant_executor(config: RunConfigs) -> NodeExecutor:
    """
    Build an executor for the AwesomeNodes of a quantized network.

    If a node's IR has non-None attrs, the node has not been transformed by
    quantization, so it is run with execute_ir. All other nodes are run with
    execute_ir_quant.

    :param config: Configuration for how to execute a node.
    :return: Executor. The executor takes the same parameters as execute_node.
    """
    def _run_quant_node(node: AwesomeNode, inputs: Dict[InputName, SiMaIRTensorTypes],
                        node_outputs: Dict[NodeName, SiMaIRTensorTypes]):
        ir = node.ir
        # Nodes skipped by quantization keep their attrs and execute as floating-point nodes.
        runner = execute_ir_quant if ir.attrs is None else execute_ir
        node_outputs[node.name] = runner(ir, inputs, config)

    return _run_quant_node
def create_node_executor(fast_mode: bool) -> NodeExecutor:
    """
    Create an executor for non-quantized AwesomeNodes.

    :param fast_mode: Value passed as RunConfigs(fast_mode=...) for the executor.
    :return: Executor created by node_executor with that configuration.
    """
    return node_executor(RunConfigs(fast_mode=fast_mode))
def create_node_quant_executor(fast_mode: bool) -> NodeExecutor:
    """
    Create an executor for the AwesomeNodes of a quantized network.

    :param fast_mode: Value passed as RunConfigs(fast_mode=...) for the executor.
    :return: Executor created by node_quant_executor with that configuration.
    """
    run_config = RunConfigs(fast_mode=fast_mode)
    return node_quant_executor(run_config)
# Module-level executor for non-quantized AwesomeNodes, configured with fast_mode=True.
execute_node: NodeExecutor = create_node_executor(fast_mode=True)
# Module-level executor for quantized AwesomeNodes, configured with fast_mode=True.
execute_node_quant: NodeExecutor = create_node_quant_executor(fast_mode=True)