#########################################################
# Copyright (C) 2020 SiMa Technologies, Inc.
#
# This material is SiMa proprietary and confidential.
#
# This material may not be copied or distributed without
# the express prior written permission of SiMa.
#
# All rights reserved.
#########################################################
# Code owner: Joey Chou
#########################################################
import numpy as np
from typing import List, Dict, Tuple, Protocol, Callable
from afe.ir.attributes import ExternalAttrs, PlaceholderAttrs, get_data_value_quant_result_scale_with_dummy, \
PlaceholderQuantAttrs
from afe.ir.defines import Status, InputName, NodeName, DataValue, Quantization, TupleValue, \
get_expected_tensor_value, get_expected_tuple_values
import afe.ir.operations as operations
from afe.ir.node import AwesomeNode, node_is_sima_ir
from afe.ir.quantization_utils import dequantize_input_dict, quantize_input_dict, quantization_data_value_to_output_list
from afe.ir.sima_ir import SiMaIR, SiMaIRTensorTypes
from afe.ir.tensor_type import TensorType, ScalarType
from afe.core.configs import RunConfigs
def _check_tensor_type(t: TensorType, a: np.ndarray) -> None:
    """
    Verify that the array a has the scalar type and the shape declared by t.

    :param t: Declared tensor type that the array is checked against.
    :param a: Array whose dtype and shape are verified.
    :raises AssertionError: If the scalar type or the shape of a differs from t.
    """
    # The declared type t is the expectation; the array a is what was actually produced.
    actual_scalar = ScalarType.from_numpy(a.dtype)
    assert t.scalar == actual_scalar, f'Expected {t.scalar}, got {actual_scalar}'
    assert t.shape == a.shape, f'Expected {t.shape}, got {a.shape}'
def _check_output_type(ir: SiMaIR, output: SiMaIRTensorTypes) -> None:
    """
    Verify that a node's output is consistent with a node's type.
    Raise an exception if it doesn't match.

    :param ir: A model graph node
    :param output: Result of executing the node in ir
    :raises AssertionError: If the number of outputs, or the scalar type or shape
        of any output tensor, does not match the node's output type.
    :raises TypeError: If output is neither an ndarray nor a tuple or list.
    """
    t = ir.get_type().output
    # Currently we only handle single tensor or non-nested tuples and lists
    if isinstance(output, np.ndarray):
        _check_tensor_type(get_expected_tensor_value(t), output)
    elif isinstance(output, (tuple, list)):
        # Check against the builtin classes; the typing aliases are meant for annotations.
        _types = get_expected_tuple_values(t)
        assert len(_types) == len(output), \
            f'Expected {len(_types)} output tensors, got {len(output)}'
        for t_item, o_item in zip(_types, output):
            _check_tensor_type(t_item, o_item)
    else:
        raise TypeError("Invalid DataValue instance")
def _check_input_shape(ir: SiMaIR, inputs: Dict[InputName, SiMaIRTensorTypes]) -> None:
    """
    Verify that the inputs have the shapes declared by the node's type.
    Raise an exception otherwise.

    :param ir: A model graph node
    :param inputs: Inputs of the node
    :raises AssertionError: If the input names (including their order), the number
        of tuple elements, or any tensor shape differs from the node's input type.
    """
    if isinstance(ir.attrs, PlaceholderAttrs) or isinstance(ir.quant_attrs, PlaceholderQuantAttrs):
        return  # PlaceholderOp does not have inputs

    t = ir.get_type().inputs
    # Use list comparison to check the order of keys
    expected_names = list(t.keys())
    actual_names = list(inputs.keys())
    assert expected_names == actual_names, \
        f"Expected inputs {expected_names}, got {actual_names}"

    for (name, ten), inp in zip(t.items(), inputs.values()):
        if isinstance(ten, TupleValue):
            assert len(ten.elements) == len(inp), \
                f"Input '{name}': expected {len(ten.elements)} tuple elements, got {len(inp)}"
            for elem_type, elem in zip(ten.elements, inp):
                assert elem_type.value.shape == elem.shape, \
                    f"Input '{name}': expected {elem_type.value.shape}, got {elem.shape}"
        else:
            # The declared type is the expectation; the supplied input is what was received.
            t_shape = ten.value.shape
            input_shape = np.asarray(inp).shape
            assert t_shape == input_shape, \
                f"Input '{name}': expected {t_shape}, got {input_shape}"
#############################
# Functions to execute SimaIR
#############################
def execute_ir(ir: SiMaIR, inputs: Dict[InputName, SiMaIRTensorTypes], config: RunConfigs) -> SiMaIRTensorTypes:
    """
    Run the non-quantized implementation of a node's SiMaIR.

    The inputs are checked against the node's input type before running, and
    the result is checked against the node's output type afterwards.

    :param ir: SiMaIR. The IR of the node to be executed.
    :param inputs: Dict[InputName, SiMaIRTensorTypes]. Inputs of the node.
    :param config: Configuration for how to execute a node.
    :return: SiMaIRTensorTypes. Results obtained by running the SiMaIR.
    """
    # Reject inputs that do not match the node's type before running the node.
    _check_input_shape(ir, inputs)

    result = ir.run(inputs, config)

    # Make sure the produced value agrees with the node's output type.
    _check_output_type(ir, result)
    return result
def execute_ir_quant(ir: SiMaIR, inputs: Dict[InputName, SiMaIRTensorTypes], config: RunConfigs) -> SiMaIRTensorTypes:
    """
    Run the quantized implementation of a node's SiMaIR.

    The inputs are checked against the node's input type before running, and
    the result is checked against the node's output type afterwards.

    :param ir: SiMaIR. The IR of the node to be executed.
    :param inputs: Dict[InputName, SiMaIRTensorTypes]. Inputs of the node.
    :param config: Configuration for how to execute a node.
    :return: SiMaIRTensorTypes. Results obtained by running the SiMaIR.
    """
    # Reject inputs that do not match the node's type before running the node.
    _check_input_shape(ir, inputs)

    result = ir.run_quant(inputs, config)

    # Make sure the produced value agrees with the node's output type.
    _check_output_type(ir, result)
    return result
##################################
# Functions to execute AwesomeNode
##################################
class NodeExecutor(Protocol):
    """
    Structural type of the callables used to execute a single node as part
    of AwesomeNet.run.

    Both the non-quantized and the quantized executor factories in this
    module return objects of this type.
    """

    def __call__(self, node: AwesomeNode, inputs: Dict[InputName, SiMaIRTensorTypes],
                 node_outputs: Dict[NodeName, SiMaIRTensorTypes]) -> None:
        """
        Execute an AwesomeNode and record its result.

        Nothing is returned; the node's output is inserted into node_outputs.

        :param node: AwesomeNode. The node to be executed.
        :param inputs: Dict[InputName, SiMaIRTensorTypes]. The inputs to the node.
        :param node_outputs: Dict[NodeName, SiMaIRTensorTypes]. Dictionary holding outputs of the nodes.
        """
        ...
def node_executor(config: RunConfigs) -> NodeExecutor:
    """
    Build an executor that runs a non-quantized AwesomeNode.

    :param config: Configuration for how to execute a node.
    :return: Executor. It runs the node's IR with execute_ir and stores the
        result in node_outputs under the node's name.
    """
    def do_execute(node: AwesomeNode, inputs: Dict[InputName, SiMaIRTensorTypes],
                   node_outputs: Dict[NodeName, SiMaIRTensorTypes]):
        # config is captured from the enclosing call.
        node_outputs[node.name] = execute_ir(node.ir, inputs, config)

    return do_execute
def node_quant_executor(config: RunConfigs) -> NodeExecutor:
    """
    Build an executor that runs a quantized AwesomeNode.

    :param config: Configuration for how to execute a node.
    :return: Executor. The executor takes the same parameters as execute_node.
    """
    def do_execute(node: AwesomeNode, inputs: Dict[InputName, SiMaIRTensorTypes],
                   node_outputs: Dict[NodeName, SiMaIRTensorTypes]):
        # A node that still has attrs has not been transformed by quantization,
        # so it is executed as if it's a floating-point node.
        run_func = execute_ir_quant if node.ir.attrs is None else execute_ir
        node_outputs[node.name] = run_func(node.ir, inputs, config)

    return do_execute
def create_node_executor(fast_mode: bool) -> NodeExecutor:
    """
    Create an executor for a non-quantized AwesomeNode.

    :param fast_mode: Value passed as fast_mode to the RunConfigs used by the executor.
    :return: Executor created by node_executor with that configuration.
    """
    return node_executor(RunConfigs(fast_mode=fast_mode))
def create_node_quant_executor(fast_mode: bool) -> NodeExecutor:
    """
    Create an executor for a quantized AwesomeNode.

    :param fast_mode: Value passed as fast_mode to the RunConfigs used by the executor.
    :return: Executor created by node_quant_executor with that configuration.
    """
    run_config = RunConfigs(fast_mode=fast_mode)
    return node_quant_executor(run_config)
# Execute a non-quantized AwesomeNode using the default configuration.
# Module-level executor, created once when this module is imported, with fast_mode=True.
execute_node: NodeExecutor = create_node_executor(fast_mode=True)
# Execute a quantized AwesomeNode using the default configuration.
# Module-level quantized executor, created once when this module is imported, with fast_mode=True.
execute_node_quant: NodeExecutor = create_node_quant_executor(fast_mode=True)