Source code for afe.ir.transform.quantization_transforms

#########################################################
# Copyright (C) 2020 SiMa Technologies, Inc.
#
# This material is SiMa proprietary and confidential.
#
# This material may not be copied or distributed without
# the express prior written permission of SiMa.
#
# All rights reserved.
#########################################################
# Code owner: Joey Chou
#########################################################
import copy
import dataclasses
from dataclasses import dataclass

import numpy as np
from termcolor import colored
from typing import Dict, Callable, Optional, Any, Tuple, TypeVar, Set, cast, List, Iterable

import afe.ir.attributes as attributes
import afe.ir.operations as operations
from afe.ir.execute import create_node_quant_executor
from afe.backends import Backend
from afe.core.configs import QuantizationConfigs, update_quantization_configs, QuantizationPrecision, Opt, \
    merge_quantization_configs
from afe.ir.build_node import create_tuple_output_node, create_tuple_get_item_nodes, create_quantization_node, \
    create_dequantization_node, create_requantization_node, create_cast_node
from afe.ir.quantization_utils import create_requantization_from_cast, cast_calibration_inputs
from afe.ir.defines import (
    NodeName, Status, DataValue, InputName, TensorValue, IdentityCast, QuantCast, DequantCast,
    InputsQuantCast, QuantizationCasts, QuantizationCast, Quantization, map_data_value, TupleCast, TupleValue,
    get_expected_tensor_value, RequantCast, RequantMethod, LogNodeReporter, RequantizationMode, ConvertCast, BiasCorrectionType
)
from afe.ir.net import AwesomeNet, update_awesomenet_status
from afe.ir.node import AwesomeNode, node_is_awesomenet, node_is_external, node_is_sima_ir, node_is_backend_ir
from afe.ir.sima_ir import SiMaIR
from afe.ir.tensor_type import ScalarType, TensorType, scalar_byte_size
from afe.ir.transform.base_transform import BaseTransform


################################
#  Node to node mappings
################################

[docs] class UpdateQuantizationConfigs(BaseTransform): """ Update quantization configs to each node in AwesomeNet :attribute asymmetry: bool. Set to True if users want to do asymmetry quantization for all layers or False for symmetric quantization unless some layers are specific configured in the custom_quantization_configs. :attribute per_channel: bool. Set to True if users want to do per-channel quantization for all layers or False for per-tensor quantization unless some layers are specific configured in the custom_quantization_configs. :attribute leaky_relu_uses_udf: Default is True. If Ture, AFE will use UDF for LeakyRelu operation. Else AFE will breakdown the LeakyRelu into multiple elementwise operationis unless some layers are specific configured in the custom_quantization_configs. :attribute custom_quantization_configs: Optional[Dict[NodeName, Dict[str, Any]]]. A dictionary using NodeName as keys. The value to each key is a dictionary of the AwesomeQuantAttr's field names and sets target configuration. Example ------- The example shows how a custom_quantization_configs looks like to config the output_int32 field in a Conv2DQuantAttrs in a output conv2d_add node to True. custom_quantization_configs = {"MLA_1/conv2d_add_84": {"output_int32": True}} """ _quant_configs: QuantizationConfigs _custom_quantization_configs: Optional[Dict[NodeName, Dict[str, Any]]] def __init__(self, quant_configs: QuantizationConfigs, custom_quantization_configs: Optional[Dict[NodeName, Dict[str, Any]]] = None): self._quant_configs = quant_configs self._custom_quantization_configs = custom_quantization_configs def __call__(self, net: AwesomeNet) -> None: for node in net.iter_nodes_recursive(): if node_is_sima_ir(node): # Initialize the unassigned configuration fields, using quant_configs. # Do not override assigned fields. c = merge_quantization_configs(config1=node.ir.quant_config, config2=self._quant_configs) # Override fields using custom_quantization_configs. override_dict = None if self._custom_quantization_configs is None \ else self._custom_quantization_configs.get(node.name) if override_dict is not None: override_c = QuantizationConfigs(**{k: Opt(v) for k, v in override_dict.items()}) c = merge_quantization_configs(config1=override_c, config2=c) node.ir.quant_config = c elif node_is_backend_ir(node) or node_is_awesomenet(node): pass else: raise TypeError("Unknown node type")
[docs] def add_external_nodes_to_float_node_list(net: AwesomeNet) -> None: """ Adding names of external nodes to the net.float_node_list attributes. net.float_node_list is used to decide whether to apply quantization transforms on a node. Parameters ---------- :param net: AwesomeNet. """ for node in net.nodes.values(): if node_is_awesomenet(node): add_external_nodes_to_float_node_list(node.ir) elif node_is_external(node): net.extend_float_node_list([node.name])
_A = TypeVar("_A") _State = TypeVar("_State") def _get_subnet_nodes_input_names(net: AwesomeNet) -> List: """ Get input node names of a net's nodes. """ input_node_names = [] if net.backend == Backend.MLA: for node_name in net.execution_order: node = net.nodes[node_name] assert isinstance(node.ir, SiMaIR) input_node_names.extend(node.input_node_names if not isinstance(node.ir.operation, operations.PlaceholderOp) else []) return input_node_names
[docs] def traverse_network(visit_node: Callable[[_State, Dict[InputName, _A], Optional[_A], AwesomeNode], _A], state: _State, inputs: Dict[InputName, _A], net: AwesomeNet) -> _A: """ Traverse a network in topological order and process each node using visit_function. Values are propagated along graph edges. The network's topological order must have been computed. The topological order upon entry to the function is used. :param visit_node: Processing to do on a network node. Processing can access shared state and can read results from processing the node's inputs. Processing returns a result to be passed to other nodes. It may modify the node and the network. :param net: Network to traverse. :param state: Mutable data used by visit_node. :param inputs: An ordered dict of values for the net's inputs. :return: Result of processing the network's output node. """ assert net.input_node_names == list(inputs.keys()), "Input keys do not match net's inputs" subnet_nodes_input_names = _get_subnet_nodes_input_names(net) # Visit nodes in topological order node_values: Dict[NodeName, _A] = {} for node_name in net.execution_order.copy(): node = net.nodes[node_name] if isinstance(node.ir, SiMaIR) and isinstance(node.ir.operation, operations.PlaceholderOp): node_inputs = {} # Special handling because of how PlaceholderOp.input_names works placeholder_value = inputs[cast(InputName, node_name)] else: node_inputs = {parameter: node_values[argument] for parameter, argument in zip(node.input_names, node.input_node_names)} placeholder_value = None node_value = visit_node(state, node_inputs, placeholder_value, node) node_values[node_name] = node_value net._prune_unneeded_activations_from_memory(node_values, node_name, protected_names=net.output_node_name) return node_values[net.output_node_name]
@dataclass
[docs] class QuantizeState: """Data used by the quantization algorithm during traversal of a network's nodes.""" # Casts that are produced during quantization. # They are saved here while quantizing nodes. # After all nodes are quantized, the casts are inserted.
[docs] casts: QuantizationCasts
# Names of nodes for which quantization is disallowed at the current subgraph. # It is used by graph analyzer to measure the effect of not quantizing a node.
[docs] disallowed_quantization_nodes: Set[NodeName]
[docs] num_calibration_inputs: int
@classmethod
[docs] def initialize(cls, num_calibration_inputs: int) -> "QuantizeState": """ Initialize a new QuantizeState for the beginning of the algorithm. """ return QuantizeState(QuantizationCasts(), set(), num_calibration_inputs)
[docs] def enter_awesome_net_scope(self, net: AwesomeNet, num_calibration_inputs: int) -> "QuantizeState": """ Prepare to process the nodes in net. Returns a new QuantizeState with information about the nodes in net. Some mutable data is shared with the original QuantizeState. """ # Share the mutable data in 'casts'. # Use the subgraph's float node list. return QuantizeState(self.casts, set(net.float_node_list), num_calibration_inputs)
def _compute_placeholder_quantization(placeholder_node: AwesomeNode) -> Quantization: """ Compute the quantization of a placeholder node based on the result of calibration. """ assert isinstance(placeholder_node.ir, SiMaIR) assert isinstance(placeholder_node.ir.operation, operations.PlaceholderOp) assert placeholder_node.ir.calib_attrs.observer is not None qrange = None # If int16 quantization is enabled, specify int16 range if placeholder_node.ir.quant_config.quantization_precision.get().is_int16_precision(): asymmetry = placeholder_node.ir.quant_config.asymmetry.get() qrange = (-32767, 32767) if not asymmetry else (-32768, 32767) quant = placeholder_node.ir.calib_attrs.observer.calculate_quantization(qrange) assert isinstance(quant, TensorValue) return quant.value def _cast_placeholder_for_mla(quantization_data: Tuple[DataValue[attributes.QuantResultTensorType], Optional[Dict[str, attributes.ObservedDistribution]]], placeholder_node: AwesomeNode) \ -> Tuple[Tuple[DataValue[attributes.QuantResultTensorType], Optional[Dict[str, attributes.ObservedDistribution]], QuantizationCast]]: """ Ensure that a placeholder node has a type that is supported on the MLA. If the type is not supported, change it to a supported type and create a cast for converting data from the original type to the supported type. The cast should be applied on the input data before it is passed to the MLA. :param quantization_data: Properties of the data that is passed in to the placeholder :param placeholder_node: The placeholder node to be changed :return: Placeholder node modified to have a type that is supported on MLA and a cast to be applied to the placeholder's input """ assert isinstance(placeholder_node.ir, SiMaIR) assert isinstance(placeholder_node.ir.operation, operations.PlaceholderOp) assert isinstance(placeholder_node.ir.attrs, operations.PlaceholderAttrs) assert placeholder_node.ir.quant_attrs is None qtype, value_distribution = quantization_data assert isinstance(qtype, TensorValue) old_type = qtype.value.type if old_type.scalar in (ScalarType.int8, ScalarType.int16, ScalarType.int32, ScalarType.bfloat16): # No conversion needed # Set placeholder node's type to match the input type placeholder_node.ir.attrs.type = dataclasses.replace(placeholder_node.ir.attrs.type, scalar=old_type.scalar) return quantization_data, IdentityCast() # No conversion needed assert old_type.scalar == ScalarType.float32 # Placeholder values are quantized depending on quantization_precision value specified # in QuantizationConfigs. qtype = placeholder_node.ir.quant_config.quantization_precision.get().to_scalar_type() if qtype == ScalarType.bfloat16: new_qtype = attributes.QuantResultTensorType.from_type(TensorType(qtype, old_type.shape)) # Change the placeholder's type placeholder_node.ir.attrs.type = dataclasses.replace( placeholder_node.ir.attrs.type, scalar=qtype ) new_data = TensorValue(new_qtype) return (new_data, value_distribution), ConvertCast(old_type.shape, old_type.scalar, qtype) else: return _quantize_and_cast_placeholder(placeholder_node, old_type.shape, qtype, value_distribution) def _quantize_and_cast_placeholder(placeholder_node: AwesomeNode, shape: Tuple[int, ...], qtype: ScalarType, value_distribution: Optional[attributes.ObservedDistribution]): # Change the placeholder's type placeholder_node.ir.attrs.type = dataclasses.replace(placeholder_node.ir.attrs.type, scalar=qtype) # Insert a conversion from float32 to int8/int32 # Compute quantization based on calibration result quantization = _compute_placeholder_quantization(placeholder_node) new_qtype = attributes.QuantResultTensorType(TensorType(qtype, shape), quantization, RequantMethod.fractional_zero) new_data = TensorValue(new_qtype) num_bits = scalar_byte_size(qtype) * 8 return (new_data, value_distribution), QuantCast(shape, quantization.scale, quantization.zero_point, num_bits, qtype) def _make_subnet_input_mla_compatible(inputs: Dict[InputName, operations.QuantizationTensorData], net: AwesomeNet) \ -> Tuple[Dict[InputName, operations.QuantizationTensorData], InputsQuantCast]: """ Convert all of the net's input placeholders to types that are supported on MLA. :return: Casts that should be applied to the net's inputs. """ assert list(inputs.keys()) == net.input_node_names input_casts: Dict[InputName, QuantizationCast] = dict() new_inputs: Dict[InputName, operations.QuantizationTensorData] = dict() for input_name, (quant_type, distribution, calibration_data) in inputs.items(): node = net.nodes[input_name] (new_input_quant, new_input_distribution), input_cast = \ _cast_placeholder_for_mla((quant_type, distribution), node) if calibration_data is not None: calibration_data = cast_calibration_inputs(calibration_data, input_cast) new_inputs[input_name] = (new_input_quant, new_input_distribution, calibration_data) input_casts[input_name] = input_cast return new_inputs, InputsQuantCast(input_casts) def _visit_subnet(state: QuantizeState, inputs: Dict[InputName, operations.QuantizationTensorData], node_name_of_net: Optional[NodeName], net: AwesomeNet) \ -> operations.QuantizationTensorData: """ Do quantization for a subnetwork. :param state: Shared data for quantization :param inputs: Quantization of the network's inputs :param node_name_of_net: Name of the node that contains the network. None for the top-level network. :param net: Network to quantize :return: Quantization of the network's outputs """ if net.backend == Backend.MLA: # Inputs must be quantized tensors assert node_name_of_net is not None inputs, input_cast = _make_subnet_input_mla_compatible(inputs, net) state.casts.insert(node_name_of_net, input_cast) subnet_state = state.enter_awesome_net_scope(net, state.num_calibration_inputs) return traverse_network(_visit_node, subnet_state, inputs, net) def _visit_node(state: QuantizeState, inputs: Dict[InputName, operations.QuantizationTensorData], placeholder_value: Optional[operations.QuantizationTensorData], node: AwesomeNode) \ -> operations.QuantizationTensorData: """ Quantize one network node as part of the quantization process. """ if isinstance(node.ir, AwesomeNet): assert placeholder_value is None return _visit_subnet(state, inputs, node.name, node.ir) elif isinstance(node.ir, SiMaIR): return _visit_operation(state, inputs, placeholder_value, node) else: raise TypeError("Unexpected IR type") @dataclass class _Counter: value: int = 0 def next(self) -> int: v = self.value self.value = v + 1 return v def _create_quantization_cast_node(net: AwesomeNet, name_counter: _Counter, cast: QuantizationCast, input_node_name: NodeName) -> NodeName: """ Create nodes for a quantization cast. :param net: Network where new nodes are inserted :param name_counter: Counter used to create new node IDs :param cast: The cast to convert to nodes :param input_node_name: The node whose output should be cast :return: Node that has the casted output. It may be input_node or a new node. """ if isinstance(cast, IdentityCast): # No node is created new_input_node_name = input_node_name elif isinstance(cast, QuantCast): # Quantize a tensor backend = Backend.NONE if net.backend == Backend.MLA else Backend.EV new_node = create_quantization_node(input_node_name, name_counter.next(), cast, backend) new_input_node_name = new_node.name net.nodes[new_input_node_name] = new_node elif isinstance(cast, DequantCast): # Dequantize a tensor backend = Backend.NONE if net.backend == Backend.MLA else Backend.EV new_node = create_dequantization_node(input_node_name, name_counter.next(), cast, backend) new_input_node_name = new_node.name net.nodes[new_input_node_name] = new_node elif isinstance(cast, RequantCast): input_type = ScalarType.int32 if cast.input_32_bit else ScalarType.int16 new_node = create_requantization_node(input_node_name, name_counter.next(), TensorType(input_type, cast.shape), cast.get_input_quantization(), cast.get_output_quantization(), cast.requant_method, create_requantization_from_cast(cast)) new_input_node_name = new_node.name net.nodes[new_input_node_name] = new_node elif isinstance(cast, ConvertCast): new_node = create_cast_node(input_node_name, name_counter.next(), cast.shape, cast.in_type, cast.out_type) new_input_node_name = new_node.name net.nodes[new_input_node_name] = new_node elif isinstance(cast, TupleCast): # Cast individual elements of a tuple input_type = net.nodes[input_node_name].get_type().output assert isinstance(input_type, TupleValue) assert len(input_type.elements) == len(cast.elements) cast_name_prefix = "quantize_" + str(name_counter.next()) # For new node names # _create_tuple_get_item_nodes only handles tensors, not tuples # Note that TensorType input_tensor_types = [] for t in input_type.elements: input_tensor_types.append(get_expected_tensor_value(t)) # Make nodes that get tuple items input_quant = net.nodes[input_node_name].ir.calib_attrs.quant tuple_nodes = create_tuple_get_item_nodes(input_node_name, input_tensor_types, cast_name_prefix, input_quant) net.nodes.update({n.name: n for n in tuple_nodes}) # Cast the tuple items cast_nodes = [] for i, (e_node, e_cast) in enumerate(zip(tuple_nodes, cast.elements)): cast_nodes.append(_create_quantization_cast_node(net, name_counter, e_cast, e_node.name)) # Make node to construct new tuple new_node = create_tuple_output_node([net.nodes[n] for n in cast_nodes], cast_name_prefix) new_input_node_name = new_node.name net.nodes[new_input_node_name] = new_node else: raise TypeError("Unexpected cast type") return new_input_node_name def _insert_quantization_casts_for_node(net: AwesomeNet, name_counter: _Counter, node: AwesomeNode, casts: InputsQuantCast) -> None: """ Insert quantization casts for the inputs of one node. New nodes are created, network edges are modified, and the nodes are added to the network. :param net: Network where new nodes are inserted :param name_counter: Counter used to create new node IDs :param node: The node whose inputs will be modified :param casts: The set of casts to apply to the node's inputs """ assert list(casts.casts.keys()) == node.input_names new_input_node_names = [] for old_input_node_name, input_cast in zip(node.input_node_names, casts.casts.values()): new_input_node_name = _create_quantization_cast_node(net, name_counter, input_cast, old_input_node_name) new_input_node_names.append(new_input_node_name) node.input_node_names = new_input_node_names def _insert_quantization_casts(net: AwesomeNet, name_counter: _Counter, casts: QuantizationCasts) -> None: """ Insert quantization casts into the net. :param net: Net to modify :param name_counter: Counter used to create unique names :param casts: Casts to insert """ # Traverse a copy of nodes, as the algorithm may modify the set of nodes net_nodes = list(net.nodes.values()) for node in net_nodes: node_casts = casts.casts.get(node.name) if node_casts is not None: _insert_quantization_casts_for_node(net, name_counter, node, node_casts) # Insert into subnetworks if isinstance(node.ir, AwesomeNet): _insert_quantization_casts(node.ir, name_counter, casts) def _is_valid_quantize_data(x: Any): calibration_data, (compat_attrs, compat_quant_attrs) = x assert isinstance(calibration_data, DataValue) assert isinstance(compat_attrs, attributes.AwesomeAttributes) assert isinstance(compat_quant_attrs, attributes.AwesomeQuantAttrBase) def _visit_operation(state: QuantizeState, inputs: Dict[InputName, operations.QuantizationTensorData], placeholder_value: Optional[operations.QuantizationTensorData], node: AwesomeNode) -> operations.QuantizationTensorData: """ Quantize a node that contains an AwesomeOperation. :param state: Shared state of the quantization algorithm :param inputs: Quantization of the node's inputs :param placeholder_value: Quantization of the placeholder's value, if this is a placeholder node. None, otherwise. :param node: Node to quantize :return: Quantization of the node's outputs """ assert isinstance(node.ir, SiMaIR) quantization_enabled = node.name not in state.disallowed_quantization_nodes # Node has to be calibrated prior to quantization assert node.status == Status.CALIBRATED # Check for type errors in the inputs. # The check ignores placeholder input types because AwesomeNet.run passes in other data, # which is ignored, for placeholders if not isinstance(node.ir.operation, operations.PlaceholderOp): assert(_is_valid_quantize_data(x) for x in inputs.values()) outputs, node_casts = node.ir.quantize(inputs, placeholder_value, quantization_enabled, LogNodeReporter(node.name)) assert isinstance(outputs[0], DataValue) # Transition status if quantization_enabled: node.status = Status.SIMA_QUANTIZED if node_casts is not None: state.casts.insert(node.name, node_casts) # set of layer-wise outputs (quant_outputs) will be None unless Iterative Bias Correction is used if (node.ir.quant_config.biascorr_type.get() == BiasCorrectionType.ITERATIVE): quant_outputs_list = quantized_operation_execution(state, inputs, placeholder_value, node, node_casts) else: quant_outputs_list = None output_list = list(outputs) output_list.append(quant_outputs_list) outputs = tuple(output_list) return outputs def _make_network_inputs(net: AwesomeNet, calibration_dataset: Iterable) \ -> Dict[InputName, operations.QuantizationTensorData]: # The algorithm is rewritten from AwesomeNet.run and it inherits the requirement to have # an input dict entry for placeholders, even though the value in the dict is ignored. input_dict: Dict[InputName, operations.QuantizationTensorData] = {} for name in net.input_node_names: node = net.nodes[name] assert isinstance(node.ir, SiMaIR) assert isinstance(node.ir.operation, operations.PlaceholderOp) input_type = get_expected_tensor_value(node.get_type().output) # The implementation puns InputName as NodeName for AwesomeNet if calibration_dataset is not None: input_list = [input[name] for input in calibration_dataset] else: input_list = None input_dict[cast(InputName, name)] = (TensorValue(attributes.QuantResultTensorType.from_type(input_type)), None, input_list) return input_dict
[docs] def quantized_operation_execution(state: QuantizeState, inputs: Dict[InputName, operations.QuantizationTensorData], placeholder_value: Optional[operations.QuantizationTensorData], node: AwesomeNode, node_casts: InputsQuantCast) -> List[np.ndarray]: # Ada round calculating output of the quantized node calibration_inputs = dict() if isinstance(node.ir.operation, operations.PlaceholderOp): assert placeholder_value is not None calibration_inputs['data'] = placeholder_value[2] else: for i, input_name in enumerate(inputs.keys()): calibration_inputs[input_name] = cast_calibration_inputs(inputs[input_name][2], node_casts.casts[input_name]) quant_outputs_list = list() for i in range(state.num_calibration_inputs): run_quant_input_dict = {} if placeholder_value is None: for input_name in inputs: run_quant_input_dict[input_name] = calibration_inputs[input_name][i] else: run_quant_input_dict['data'] = calibration_inputs['data'][i] quant_output_dict = dict() executor = create_node_quant_executor(False) executor(node, run_quant_input_dict, quant_output_dict) quant_outputs = quant_output_dict[node.name] quant_outputs_list.append(quant_outputs) return quant_outputs_list
[docs] class Quantize(BaseTransform): """Quantizes an AwesomeNet""" def __call__(self, net: AwesomeNet, input_dataset: Iterable | None) -> None: _msg = "Running quantization ..." print(colored(_msg, "green"), end="\r") # Do quantization old_output_type: DataValue[TensorType] = net.nodes[net.output_node_name].get_type().output old_output_qtype = map_data_value(attributes.QuantResultTensorType.from_type, old_output_type) quantization_state = QuantizeState.initialize(len(input_dataset) if input_dataset is not None else 0) network_inputs = _make_network_inputs(net, input_dataset) quantized_output_qtype, _, _ = _visit_subnet(quantization_state, network_inputs, None, net) output_cast = operations.make_quantization_cast(quantized_output_qtype, old_output_qtype) # Insert the produced casts into the network name_counter = _Counter() _insert_quantization_casts(net, name_counter, quantization_state.casts) net.output_node_name = _create_quantization_cast_node(net, name_counter, output_cast, net.output_node_name) # Recalculate topological ordering of the network and subnetworks def do_topological_sort(net: AwesomeNet): net.topological_sort() for node in net.nodes.values(): if isinstance(node.ir, AwesomeNet): do_topological_sort(node.ir) do_topological_sort(net) # Update AwesomeNet status update_awesomenet_status(net, Status.SIMA_QUANTIZED) print(colored(_msg, "green") + colored("DONE", "yellow"))