#########################################################
# Copyright (C) 2020 SiMa Technologies, Inc.
#
# This material is SiMa proprietary and confidential.
#
# This material may not be copied or distributed without
# the express prior written permission of SiMa.
#
# All rights reserved.
#########################################################
# Code owner: Joey Chou
#########################################################
import copy
import dataclasses
from dataclasses import dataclass
import numpy as np
from termcolor import colored
from typing import Dict, Callable, Optional, Any, Tuple, TypeVar, Set, cast, List, Iterable
import afe.ir.attributes as attributes
import afe.ir.operations as operations
from afe.ir.execute import create_node_quant_executor
from afe.backends import Backend
from afe.core.configs import QuantizationConfigs, update_quantization_configs, QuantizationPrecision, Opt, \
merge_quantization_configs
from afe.ir.build_node import create_tuple_output_node, create_tuple_get_item_nodes, create_quantization_node, \
create_dequantization_node, create_requantization_node, create_cast_node
from afe.ir.quantization_utils import create_requantization_from_cast, cast_calibration_inputs
from afe.ir.defines import (
NodeName, Status, DataValue, InputName, TensorValue, IdentityCast, QuantCast, DequantCast,
InputsQuantCast, QuantizationCasts, QuantizationCast, Quantization, map_data_value, TupleCast, TupleValue,
get_expected_tensor_value, RequantCast, RequantMethod, LogNodeReporter, RequantizationMode, ConvertCast, BiasCorrectionType
)
from afe.ir.net import AwesomeNet, update_awesomenet_status
from afe.ir.node import AwesomeNode, node_is_awesomenet, node_is_external, node_is_sima_ir, node_is_backend_ir
from afe.ir.sima_ir import SiMaIR
from afe.ir.tensor_type import ScalarType, TensorType, scalar_byte_size
from afe.ir.transform.base_transform import BaseTransform
################################
# Node to node mappings
################################
class UpdateQuantizationConfigs(BaseTransform):
"""
Update quantization configs to each node in AwesomeNet
:attribute asymmetry: bool. Set to True if users want to do asymmetry quantization for all layers
or False for symmetric quantization unless some layers are specific configured in
the custom_quantization_configs.
:attribute per_channel: bool. Set to True if users want to do per-channel quantization for all layers
or False for per-tensor quantization unless some layers are specific configured in
the custom_quantization_configs.
:attribute leaky_relu_uses_udf: Default is True. If Ture, AFE will use UDF for LeakyRelu operation. Else
AFE will breakdown the LeakyRelu into multiple elementwise operationis unless some layers are specific
configured in the custom_quantization_configs.
:attribute custom_quantization_configs: Optional[Dict[NodeName, Dict[str, Any]]]. A dictionary using NodeName
as keys. The value to each key is a dictionary of the AwesomeQuantAttr's field names and sets target
configuration.
Example
-------
The example shows how a custom_quantization_configs looks like to config the output_int32 field in a
Conv2DQuantAttrs in a output conv2d_add node to True.
custom_quantization_configs = {"MLA_1/conv2d_add_84": {"output_int32": True}}
"""
_quant_configs: QuantizationConfigs
_custom_quantization_configs: Optional[Dict[NodeName, Dict[str, Any]]]
def __init__(self, quant_configs: QuantizationConfigs,
custom_quantization_configs: Optional[Dict[NodeName, Dict[str, Any]]] = None):
self._quant_configs = quant_configs
self._custom_quantization_configs = custom_quantization_configs
def __call__(self, net: AwesomeNet) -> None:
for node in net.iter_nodes_recursive():
if node_is_sima_ir(node):
# Initialize the unassigned configuration fields, using quant_configs.
# Do not override assigned fields.
c = merge_quantization_configs(config1=node.ir.quant_config, config2=self._quant_configs)
# Override fields using custom_quantization_configs.
override_dict = None if self._custom_quantization_configs is None \
else self._custom_quantization_configs.get(node.name)
if override_dict is not None:
override_c = QuantizationConfigs(**{k: Opt(v) for k, v in override_dict.items()})
c = merge_quantization_configs(config1=override_c, config2=c)
node.ir.quant_config = c
elif node_is_backend_ir(node) or node_is_awesomenet(node):
pass
else:
raise TypeError("Unknown node type")
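
# Illustrative usage of UpdateQuantizationConfigs (a minimal sketch; `net` and `configs` are
# assumed to be an AwesomeNet and a QuantizationConfigs created elsewhere, and the node name
# below is the hypothetical one from the docstring example):
#
#     custom = {"MLA_1/conv2d_add_84": {"output_int32": True}}
#     UpdateQuantizationConfigs(configs, custom)(net)
#
# Per-node overrides are wrapped in Opt(...) and merged on top of the node's existing config,
# so only the listed fields change for the named node.
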
def add_external_nodes_to_float_node_list(net: AwesomeNet) -> None:
"""
Adding names of external nodes to the net.float_node_list attributes.
net.float_node_list is used to decide whether to apply quantization transforms on a node.
Parameters
----------
:param net: AwesomeNet.
"""
for node in net.nodes.values():
if node_is_awesomenet(node):
add_external_nodes_to_float_node_list(node.ir)
elif node_is_external(node):
net.extend_float_node_list([node.name])
_A = TypeVar("_A")
_State = TypeVar("_State")
def _get_subnet_nodes_input_names(net: AwesomeNet) -> List[NodeName]:
    """
    Get the input node names of a net's nodes, excluding placeholder nodes. Only MLA
    subgraphs are examined; an empty list is returned for other backends.
    """
input_node_names = []
if net.backend == Backend.MLA:
for node_name in net.execution_order:
node = net.nodes[node_name]
assert isinstance(node.ir, SiMaIR)
            if not isinstance(node.ir.operation, operations.PlaceholderOp):
                input_node_names.extend(node.input_node_names)
return input_node_names
def traverse_network(visit_node: Callable[[_State, Dict[InputName, _A], Optional[_A], AwesomeNode], _A],
state: _State,
inputs: Dict[InputName, _A],
net: AwesomeNet) -> _A:
"""
Traverse a network in topological order and process each node using visit_function.
Values are propagated along graph edges.
The network's topological order must have been computed. The topological order
upon entry to the function is used.
:param visit_node: Processing to do on a network node. Processing can access
shared state and can read results from processing the node's inputs. Processing
returns a result to be passed to other nodes.
It may modify the node and the network.
:param net: Network to traverse.
:param state: Mutable data used by visit_node.
:param inputs: An ordered dict of values for the net's inputs.
:return: Result of processing the network's output node.
"""
assert net.input_node_names == list(inputs.keys()), "Input keys do not match net's inputs"
subnet_nodes_input_names = _get_subnet_nodes_input_names(net)
# Visit nodes in topological order
node_values: Dict[NodeName, _A] = {}
for node_name in net.execution_order.copy():
node = net.nodes[node_name]
if isinstance(node.ir, SiMaIR) and isinstance(node.ir.operation, operations.PlaceholderOp):
node_inputs = {} # Special handling because of how PlaceholderOp.input_names works
placeholder_value = inputs[cast(InputName, node_name)]
else:
node_inputs = {parameter: node_values[argument]
for parameter, argument in zip(node.input_names, node.input_node_names)}
placeholder_value = None
node_value = visit_node(state, node_inputs, placeholder_value, node)
node_values[node_name] = node_value
net._prune_unneeded_activations_from_memory(node_values, node_name, protected_names=net.output_node_name)
return node_values[net.output_node_name]
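
# Illustrative visit function for traverse_network (a minimal sketch; the visitor and its state
# are hypothetical and not part of the quantization algorithm). It records each node name in the
# shared state and propagates None along the graph edges:
#
#     def _record_names(state, node_inputs, placeholder_value, node):
#         # node_inputs maps this node's parameter names to the values produced by its inputs;
#         # placeholder_value is non-None only for placeholder nodes.
#         state.append(node.name)
#         return None  # value passed to the nodes that consume this node's output
#
#     visited: List[NodeName] = []
#     traverse_network(_record_names, visited, {name: None for name in net.input_node_names}, net)
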
@dataclass
class QuantizeState:
"""Data used by the quantization algorithm during traversal of a network's nodes."""
# Casts that are produced during quantization.
# They are saved here while quantizing nodes.
# After all nodes are quantized, the casts are inserted.
casts: QuantizationCasts
# Names of nodes for which quantization is disallowed at the current subgraph.
# It is used by graph analyzer to measure the effect of not quantizing a node.
    disallowed_quantization_nodes: Set[NodeName]

    # Number of calibration samples available for executing quantized nodes
    # (used by iterative bias correction).
    num_calibration_inputs: int
@classmethod
def initialize(cls, num_calibration_inputs: int) -> "QuantizeState":
"""
Initialize a new QuantizeState for the beginning of the algorithm.
"""
return QuantizeState(QuantizationCasts(), set(), num_calibration_inputs)
def enter_awesome_net_scope(self, net: AwesomeNet, num_calibration_inputs: int) -> "QuantizeState":
"""
Prepare to process the nodes in net. Returns a new QuantizeState with
information about the nodes in net. Some mutable data is shared with
the original QuantizeState.
"""
# Share the mutable data in 'casts'.
# Use the subgraph's float node list.
return QuantizeState(self.casts, set(net.float_node_list), num_calibration_inputs)
def _compute_placeholder_quantization(placeholder_node: AwesomeNode) -> Quantization:
"""
Compute the quantization of a placeholder node based on the result of calibration.
"""
assert isinstance(placeholder_node.ir, SiMaIR)
assert isinstance(placeholder_node.ir.operation, operations.PlaceholderOp)
assert placeholder_node.ir.calib_attrs.observer is not None
qrange = None
# If int16 quantization is enabled, specify int16 range
if placeholder_node.ir.quant_config.quantization_precision.get().is_int16_precision():
asymmetry = placeholder_node.ir.quant_config.asymmetry.get()
qrange = (-32767, 32767) if not asymmetry else (-32768, 32767)
quant = placeholder_node.ir.calib_attrs.observer.calculate_quantization(qrange)
assert isinstance(quant, TensorValue)
return quant.value
def _cast_placeholder_for_mla(quantization_data: Tuple[DataValue[attributes.QuantResultTensorType],
                                                       Optional[Dict[str, attributes.ObservedDistribution]]],
                              placeholder_node: AwesomeNode) \
        -> Tuple[Tuple[DataValue[attributes.QuantResultTensorType],
                       Optional[Dict[str, attributes.ObservedDistribution]]], QuantizationCast]:
"""
Ensure that a placeholder node has a type that is supported on the MLA. If the type
is not supported, change it to a supported type and create a cast for converting data
from the original type to the supported type. The cast should be applied on the
input data before it is passed to the MLA.
:param quantization_data: Properties of the data that is passed in to the placeholder
:param placeholder_node: The placeholder node to be changed
    :return: The quantization data, updated to a type that is supported on the MLA,
        and a cast to be applied to the placeholder's input. The placeholder node is
        modified in place.
"""
assert isinstance(placeholder_node.ir, SiMaIR)
assert isinstance(placeholder_node.ir.operation, operations.PlaceholderOp)
assert isinstance(placeholder_node.ir.attrs, operations.PlaceholderAttrs)
assert placeholder_node.ir.quant_attrs is None
qtype, value_distribution = quantization_data
assert isinstance(qtype, TensorValue)
old_type = qtype.value.type
    if old_type.scalar in (ScalarType.int8, ScalarType.int16, ScalarType.int32, ScalarType.bfloat16):
        # No conversion needed. Set the placeholder node's type to match the input type.
        placeholder_node.ir.attrs.type = dataclasses.replace(placeholder_node.ir.attrs.type, scalar=old_type.scalar)
        return quantization_data, IdentityCast()
assert old_type.scalar == ScalarType.float32
# Placeholder values are quantized depending on quantization_precision value specified
# in QuantizationConfigs.
qtype = placeholder_node.ir.quant_config.quantization_precision.get().to_scalar_type()
if qtype == ScalarType.bfloat16:
new_qtype = attributes.QuantResultTensorType.from_type(TensorType(qtype, old_type.shape))
# Change the placeholder's type
placeholder_node.ir.attrs.type = dataclasses.replace(
placeholder_node.ir.attrs.type, scalar=qtype
)
new_data = TensorValue(new_qtype)
return (new_data, value_distribution), ConvertCast(old_type.shape, old_type.scalar, qtype)
else:
return _quantize_and_cast_placeholder(placeholder_node, old_type.shape, qtype, value_distribution)
def _quantize_and_cast_placeholder(placeholder_node: AwesomeNode, shape: Tuple[int, ...], qtype: ScalarType,
value_distribution: Optional[attributes.ObservedDistribution]):
# Change the placeholder's type
placeholder_node.ir.attrs.type = dataclasses.replace(placeholder_node.ir.attrs.type, scalar=qtype)
    # Insert a conversion from float32 to the quantized integer type
# Compute quantization based on calibration result
quantization = _compute_placeholder_quantization(placeholder_node)
new_qtype = attributes.QuantResultTensorType(TensorType(qtype, shape), quantization,
RequantMethod.fractional_zero)
new_data = TensorValue(new_qtype)
num_bits = scalar_byte_size(qtype) * 8
return (new_data, value_distribution), QuantCast(shape, quantization.scale, quantization.zero_point,
num_bits, qtype)
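
# For example, with int8 precision scalar_byte_size(qtype) is 1, so the QuantCast above is created
# with num_bits = 8; with int16 precision it is created with num_bits = 16. The scale and zero
# point come from the observer's calibration result computed in _compute_placeholder_quantization.
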
def _make_subnet_input_mla_compatible(inputs: Dict[InputName, operations.QuantizationTensorData], net: AwesomeNet) \
-> Tuple[Dict[InputName, operations.QuantizationTensorData], InputsQuantCast]:
"""
Convert all of the net's input placeholders to types that are supported on MLA.
:return: Casts that should be applied to the net's inputs.
"""
assert list(inputs.keys()) == net.input_node_names
input_casts: Dict[InputName, QuantizationCast] = dict()
new_inputs: Dict[InputName, operations.QuantizationTensorData] = dict()
for input_name, (quant_type, distribution, calibration_data) in inputs.items():
node = net.nodes[input_name]
(new_input_quant, new_input_distribution), input_cast = \
_cast_placeholder_for_mla((quant_type, distribution), node)
if calibration_data is not None:
calibration_data = cast_calibration_inputs(calibration_data, input_cast)
new_inputs[input_name] = (new_input_quant, new_input_distribution, calibration_data)
input_casts[input_name] = input_cast
return new_inputs, InputsQuantCast(input_casts)
def _visit_subnet(state: QuantizeState, inputs: Dict[InputName, operations.QuantizationTensorData],
node_name_of_net: Optional[NodeName], net: AwesomeNet) \
-> operations.QuantizationTensorData:
"""
Do quantization for a subnetwork.
:param state: Shared data for quantization
:param inputs: Quantization of the network's inputs
:param node_name_of_net: Name of the node that contains the network. None for the top-level network.
:param net: Network to quantize
:return: Quantization of the network's outputs
"""
if net.backend == Backend.MLA:
# Inputs must be quantized tensors
assert node_name_of_net is not None
inputs, input_cast = _make_subnet_input_mla_compatible(inputs, net)
state.casts.insert(node_name_of_net, input_cast)
subnet_state = state.enter_awesome_net_scope(net, state.num_calibration_inputs)
return traverse_network(_visit_node, subnet_state, inputs, net)
def _visit_node(state: QuantizeState, inputs: Dict[InputName, operations.QuantizationTensorData],
placeholder_value: Optional[operations.QuantizationTensorData], node: AwesomeNode) \
-> operations.QuantizationTensorData:
"""
Quantize one network node as part of the quantization process.
"""
if isinstance(node.ir, AwesomeNet):
assert placeholder_value is None
return _visit_subnet(state, inputs, node.name, node.ir)
elif isinstance(node.ir, SiMaIR):
return _visit_operation(state, inputs, placeholder_value, node)
else:
raise TypeError("Unexpected IR type")
@dataclass
class _Counter:
value: int = 0
def next(self) -> int:
v = self.value
self.value = v + 1
return v
def _create_quantization_cast_node(net: AwesomeNet, name_counter: _Counter,
cast: QuantizationCast, input_node_name: NodeName) -> NodeName:
"""
Create nodes for a quantization cast.
:param net: Network where new nodes are inserted
:param name_counter: Counter used to create new node IDs
:param cast: The cast to convert to nodes
:param input_node_name: The node whose output should be cast
    :return: Name of the node that has the cast output. It may be input_node_name or the name of a new node.
"""
if isinstance(cast, IdentityCast):
# No node is created
new_input_node_name = input_node_name
elif isinstance(cast, QuantCast):
# Quantize a tensor
backend = Backend.NONE if net.backend == Backend.MLA else Backend.EV
new_node = create_quantization_node(input_node_name, name_counter.next(), cast, backend)
new_input_node_name = new_node.name
net.nodes[new_input_node_name] = new_node
elif isinstance(cast, DequantCast):
# Dequantize a tensor
backend = Backend.NONE if net.backend == Backend.MLA else Backend.EV
new_node = create_dequantization_node(input_node_name, name_counter.next(), cast, backend)
new_input_node_name = new_node.name
net.nodes[new_input_node_name] = new_node
elif isinstance(cast, RequantCast):
input_type = ScalarType.int32 if cast.input_32_bit else ScalarType.int16
new_node = create_requantization_node(input_node_name, name_counter.next(), TensorType(input_type, cast.shape),
cast.get_input_quantization(), cast.get_output_quantization(),
cast.requant_method, create_requantization_from_cast(cast))
new_input_node_name = new_node.name
net.nodes[new_input_node_name] = new_node
elif isinstance(cast, ConvertCast):
new_node = create_cast_node(input_node_name, name_counter.next(), cast.shape, cast.in_type, cast.out_type)
new_input_node_name = new_node.name
net.nodes[new_input_node_name] = new_node
elif isinstance(cast, TupleCast):
# Cast individual elements of a tuple
input_type = net.nodes[input_node_name].get_type().output
assert isinstance(input_type, TupleValue)
assert len(input_type.elements) == len(cast.elements)
cast_name_prefix = "quantize_" + str(name_counter.next()) # For new node names
        # create_tuple_get_item_nodes only handles tensors, not tuples
        input_tensor_types = [get_expected_tensor_value(t) for t in input_type.elements]
# Make nodes that get tuple items
input_quant = net.nodes[input_node_name].ir.calib_attrs.quant
tuple_nodes = create_tuple_get_item_nodes(input_node_name, input_tensor_types, cast_name_prefix, input_quant)
net.nodes.update({n.name: n for n in tuple_nodes})
# Cast the tuple items
cast_nodes = []
for i, (e_node, e_cast) in enumerate(zip(tuple_nodes, cast.elements)):
cast_nodes.append(_create_quantization_cast_node(net, name_counter, e_cast, e_node.name))
# Make node to construct new tuple
new_node = create_tuple_output_node([net.nodes[n] for n in cast_nodes], cast_name_prefix)
new_input_node_name = new_node.name
net.nodes[new_input_node_name] = new_node
else:
raise TypeError("Unexpected cast type")
return new_input_node_name
def _insert_quantization_casts_for_node(net: AwesomeNet, name_counter: _Counter,
node: AwesomeNode, casts: InputsQuantCast) -> None:
"""
Insert quantization casts for the inputs of one node.
New nodes are created, network edges are modified, and the nodes are added to the network.
:param net: Network where new nodes are inserted
:param name_counter: Counter used to create new node IDs
:param node: The node whose inputs will be modified
:param casts: The set of casts to apply to the node's inputs
"""
assert list(casts.casts.keys()) == node.input_names
new_input_node_names = []
for old_input_node_name, input_cast in zip(node.input_node_names, casts.casts.values()):
new_input_node_name = _create_quantization_cast_node(net, name_counter, input_cast, old_input_node_name)
new_input_node_names.append(new_input_node_name)
node.input_node_names = new_input_node_names
def _insert_quantization_casts(net: AwesomeNet, name_counter: _Counter, casts: QuantizationCasts) -> None:
"""
Insert quantization casts into the net.
:param net: Net to modify
:param name_counter: Counter used to create unique names
:param casts: Casts to insert
"""
# Traverse a copy of nodes, as the algorithm may modify the set of nodes
net_nodes = list(net.nodes.values())
for node in net_nodes:
node_casts = casts.casts.get(node.name)
if node_casts is not None:
_insert_quantization_casts_for_node(net, name_counter, node, node_casts)
# Insert into subnetworks
if isinstance(node.ir, AwesomeNet):
_insert_quantization_casts(node.ir, name_counter, casts)
def _is_valid_quantize_data(x: Any):
calibration_data, (compat_attrs, compat_quant_attrs) = x
assert isinstance(calibration_data, DataValue)
assert isinstance(compat_attrs, attributes.AwesomeAttributes)
assert isinstance(compat_quant_attrs, attributes.AwesomeQuantAttrBase)
def _visit_operation(state: QuantizeState, inputs: Dict[InputName, operations.QuantizationTensorData],
placeholder_value: Optional[operations.QuantizationTensorData],
node: AwesomeNode) -> operations.QuantizationTensorData:
"""
Quantize a node that contains an AwesomeOperation.
:param state: Shared state of the quantization algorithm
:param inputs: Quantization of the node's inputs
:param placeholder_value: Quantization of the placeholder's value, if this is a
placeholder node. None, otherwise.
:param node: Node to quantize
:return: Quantization of the node's outputs
"""
assert isinstance(node.ir, SiMaIR)
quantization_enabled = node.name not in state.disallowed_quantization_nodes
# Node has to be calibrated prior to quantization
assert node.status == Status.CALIBRATED
# Check for type errors in the inputs.
# The check ignores placeholder input types because AwesomeNet.run passes in other data,
# which is ignored, for placeholders
if not isinstance(node.ir.operation, operations.PlaceholderOp):
assert(_is_valid_quantize_data(x) for x in inputs.values())
outputs, node_casts = node.ir.quantize(inputs, placeholder_value, quantization_enabled,
LogNodeReporter(node.name))
assert isinstance(outputs[0], DataValue)
# Transition status
if quantization_enabled:
node.status = Status.SIMA_QUANTIZED
if node_casts is not None:
state.casts.insert(node.name, node_casts)
    # The layer-wise quantized outputs (quant_outputs_list) are None unless iterative bias correction is used.
    if node.ir.quant_config.biascorr_type.get() == BiasCorrectionType.ITERATIVE:
quant_outputs_list = quantized_operation_execution(state, inputs, placeholder_value, node, node_casts)
else:
quant_outputs_list = None
output_list = list(outputs)
output_list.append(quant_outputs_list)
outputs = tuple(output_list)
return outputs
def _make_network_inputs(net: AwesomeNet, calibration_dataset: Optional[Iterable]) \
        -> Dict[InputName, operations.QuantizationTensorData]:
# The algorithm is rewritten from AwesomeNet.run and it inherits the requirement to have
# an input dict entry for placeholders, even though the value in the dict is ignored.
input_dict: Dict[InputName, operations.QuantizationTensorData] = {}
for name in net.input_node_names:
node = net.nodes[name]
assert isinstance(node.ir, SiMaIR)
assert isinstance(node.ir.operation, operations.PlaceholderOp)
input_type = get_expected_tensor_value(node.get_type().output)
# The implementation puns InputName as NodeName for AwesomeNet
if calibration_dataset is not None:
            input_list = [sample[name] for sample in calibration_dataset]
else:
input_list = None
input_dict[cast(InputName, name)] = (TensorValue(attributes.QuantResultTensorType.from_type(input_type)), None,
input_list)
return input_dict
def quantized_operation_execution(state: QuantizeState, inputs: Dict[InputName, operations.QuantizationTensorData],
placeholder_value: Optional[operations.QuantizationTensorData],
node: AwesomeNode, node_casts: InputsQuantCast) -> List[np.ndarray]:
    # Compute the quantized node's outputs over the calibration inputs; these layer-wise
    # outputs are used for iterative bias correction.
calibration_inputs = dict()
if isinstance(node.ir.operation, operations.PlaceholderOp):
assert placeholder_value is not None
calibration_inputs['data'] = placeholder_value[2]
else:
        for input_name in inputs:
            calibration_inputs[input_name] = cast_calibration_inputs(inputs[input_name][2],
                                                                     node_casts.casts[input_name])
quant_outputs_list = list()
for i in range(state.num_calibration_inputs):
run_quant_input_dict = {}
if placeholder_value is None:
for input_name in inputs:
run_quant_input_dict[input_name] = calibration_inputs[input_name][i]
else:
run_quant_input_dict['data'] = calibration_inputs['data'][i]
quant_output_dict = dict()
executor = create_node_quant_executor(False)
executor(node, run_quant_input_dict, quant_output_dict)
quant_outputs = quant_output_dict[node.name]
quant_outputs_list.append(quant_outputs)
return quant_outputs_list
class Quantize(BaseTransform):
"""Quantizes an AwesomeNet"""
def __call__(self, net: AwesomeNet, input_dataset: Iterable | None) -> None:
_msg = "Running quantization ..."
print(colored(_msg, "green"), end="\r")
# Do quantization
old_output_type: DataValue[TensorType] = net.nodes[net.output_node_name].get_type().output
old_output_qtype = map_data_value(attributes.QuantResultTensorType.from_type, old_output_type)
quantization_state = QuantizeState.initialize(len(input_dataset) if input_dataset is not None else 0)
network_inputs = _make_network_inputs(net, input_dataset)
quantized_output_qtype, _, _ = _visit_subnet(quantization_state, network_inputs, None, net)
output_cast = operations.make_quantization_cast(quantized_output_qtype, old_output_qtype)
# Insert the produced casts into the network
name_counter = _Counter()
_insert_quantization_casts(net, name_counter, quantization_state.casts)
net.output_node_name = _create_quantization_cast_node(net, name_counter, output_cast, net.output_node_name)
# Recalculate topological ordering of the network and subnetworks
def do_topological_sort(net: AwesomeNet):
net.topological_sort()
for node in net.nodes.values():
if isinstance(node.ir, AwesomeNet):
do_topological_sort(node.ir)
do_topological_sort(net)
# Update AwesomeNet status
update_awesomenet_status(net, Status.SIMA_QUANTIZED)
print(colored(_msg, "green") + colored("DONE", "yellow"))