Source code for afe.ir.transform.requantization_fusion

#########################################################
# Copyright (C) 2023 SiMa Technologies, Inc.
#
# This material is SiMa proprietary and confidential.
#
# This material may not be copied or distributed without
# the express prior written permission of SiMa.
#
# All rights reserved.
#########################################################
# Code owner: Christopher Rodrigues
#########################################################
import dataclasses
import numpy as np
from typing import Dict, List, Optional, Tuple, Union, Callable, Type

from ml_kernels.requantization import (
    is_identity_requantization, is_arith_folded_requantization, is_tflite_requantization,
    is_fractional_zero_requantization, BaseRequantization, TFLiteRequantization, FractionalZeroRequantization,
    ArithFoldedRequantization, Narrowing, requantize
)
from sima_utils.logging import sima_logger

from afe.ir.attributes import (
    AddQuantAttrs, RequantizeQuantAttrs, AwesomeQuantAttrBase, AwesomeCalibAttrs, SubtractQuantAttrs,
    MultiplyQuantAttrs, ConvQuantAttrs, PoolQuantAttrs, UDFQuantAttrs, QuantizationTransformAttrs,
    ConstantQuantAttrs, AwesomeAttributes
)
from afe.ir.defines import NodeName
from afe.ir.net import AwesomeNet, Renaming, rename_mut_awesomenet
from afe.ir.node import node_is_sima_ir, AwesomeNode, node_is_constant, node_is_requantization_node
from afe.ir.operations import (
    RequantizeOp, AddActivationOp, SubtractOp, MultiplyOp, MaxPool2DOp, AvgPool2DOp, ConvAddActivationOp, UDFOp,
    ConstantOp, QuantizationTransformOp, AwesomeOperation
)
from afe.ir.quantization_utils import requantize_activation
from afe.ir.tensor_type import ScalarType, scalar_is_integral, scalar_byte_size
from afe.ir.transform.base_transform import BaseTransform

# Attribute object handled by the fusion helpers in this module: either a node's plain
# operator attributes or its quantized attributes.
_NodeAttributes = Union[AwesomeAttributes, AwesomeQuantAttrBase]


[docs] def convert_to_arith_folded_requantization(requant: BaseRequantization[np.ndarray]) \ -> Optional[ArithFoldedRequantization[np.ndarray]]: """ Convert the given requantization to an ArithFoldedRequantization, if possible. Only convert if the ArithFoldedRequantization is exactly equivalent, including rounding and saturation behavior. Args: requant: Requantization to convert Returns: An ArithFoldedRequantization that is equivalent to the input. None if there is no equivalent ArithFoldedRequantization. """ match requant: case ArithFoldedRequantization(): return requant case FractionalZeroRequantization(sc_correction=1, zp_correction=0, narrowing=narrowing): return ArithFoldedRequantization(narrowing) case TFLiteRequantization(sc_correction=1, zp_correction=0, shift=shift, rounding=rounding, out_dtype=dtype): return ArithFoldedRequantization(Narrowing(shift, rounding, dtype)) # Else, cannot convert return None
def can_convert_to_arith_folded_requantization(requant: BaseRequantization[np.ndarray]) -> bool:
    """
    Report whether convert_to_arith_folded_requantization would succeed for the given requantization.

    Args:
        requant: Requantization to examine

    Returns:
        True if an equivalent ArithFoldedRequantization exists, otherwise False.
    """
    converted = convert_to_arith_folded_requantization(requant)
    return converted is not None
def _collect_successors( net: AwesomeNet, node_condition: Callable[[AwesomeNode], bool] = lambda x: True, successor_condition: Callable[[AwesomeNode], bool] = lambda x: True ) -> Dict[NodeName, List[NodeName]]: """ Find all successors of all nodes, restricting the results to nodes and successors that satisfy the given conditions. :param net: Net to analyze. :param node_condition: The condition that needs to be fulfilled for an input node to be included in the returned dictionary. :param successor_condition: The condition that needs to be fulfilled for a successor node to be included in the returned dictionary. :return: Successors of nodes. It is a dict d such that for all nodes m and n, m is an input of n iff (n in d[m]) is true. """ successors = {node_name: list() for node_name, node_value in net.nodes.items() if node_condition(node_value)} for node in net.nodes.values(): if successor_condition(node): for input_name in node.input_node_names: if node.name != input_name and node_condition(net.nodes[input_name]): successors[input_name].append(node.name) return successors def _generate_fused_node(node: AwesomeNode, successor: AwesomeNode, new_quant_attrs: _NodeAttributes) -> AwesomeNode: """ Create a fused node to replace "node" and "successor". Use the operator from node and use the given attributes. The caller must ensure that fusion is valid. This function is for fusion that eliminates a RequantizeOp node. 
:param node: The first node to be fused :param successor: The second node to be fused :param new_quant_attrs: Quantized attributes to use in the fused node :return: The fused node """ # Collect calibration information from the input and output new_calib_attrs = AwesomeCalibAttrs(observer=successor.ir.calib_attrs.observer, input_quant=node.ir.calib_attrs.input_quant, quant=successor.ir.calib_attrs.quant) _attrs = new_quant_attrs if isinstance(new_quant_attrs, AwesomeAttributes) else None _quant_attrs = new_quant_attrs if isinstance(new_quant_attrs, AwesomeQuantAttrBase) else None new_ir = dataclasses.replace(node.ir, _attrs=_attrs, calib_attrs=new_calib_attrs, _quant_attrs=_quant_attrs) return dataclasses.replace(node, ir=new_ir) def _modify_add_activ_attrs_for_fusion(attrs: _NodeAttributes, requant_attrs: RequantizeQuantAttrs) \ -> Optional[_NodeAttributes]: """ Modifies the quantization attributes for AddActivation node to incorporate the requantization parameters from the succeeding Requantization node. :param attrs: Quantization attributes of the original AddActivation node. :param requant_attrs: Requantization attributes of the succeeding Requantization node. :return: Quantization attributes incorporating the requantization parameters, if the Requantization node can be fused, otherwise None. """ assert isinstance(attrs, AddQuantAttrs) assert isinstance(attrs.requant, FractionalZeroRequantization) if not is_identity_requantization(attrs.requant) or attrs.requant.out_dtype != np.int32: return None if attrs.activ_attrs is not None: # Fusion would make the activation run after requantization instead of before, # so it must be requantized. 
scalar_type = ScalarType.int16 if attrs.input_int16 else ScalarType.int8 new_activ_attrs = requantize_activation(attrs.activ_attrs, attrs.relu_zero_point, requant_attrs.requant, scalar_type) else: new_activ_attrs = None # Compute the new output zero point new_zero_point = int(requantize(np.array(attrs.relu_zero_point), requant_attrs.requant).item()) return dataclasses.replace(attrs, requant=requant_attrs.requant, relu_zero_point=new_zero_point, activ_attrs=new_activ_attrs) def _modify_pool_attrs_for_fusion(attrs: _NodeAttributes, requant_attrs: RequantizeQuantAttrs) \ -> Optional[_NodeAttributes]: """ Modifies the quantization attributes for MaxPool or AvgPool node to incorporate the requantization parameters from the succeeding Requantization node. :param attrs: Quantization attributes of the original Pool node. :param requant_attrs: Requantization attributes of the succeeding Requantization node. :return: Quantization attributes incorporating the requantization parameters, if the Requantization node can be fused, otherwise None. """ assert isinstance(attrs, PoolQuantAttrs) assert isinstance(attrs.requant, TFLiteRequantization) if not is_identity_requantization(attrs.requant) or attrs.requant.out_dtype != np.int32: return None return dataclasses.replace(attrs, requant=requant_attrs.requant) def _modify_mul_sub_attrs_for_fusion(attrs: _NodeAttributes, requant_attrs: RequantizeQuantAttrs) \ -> Optional[_NodeAttributes]: """ Modifies the quantization attributes for Substract or Multiply node to incorporate the requantization parameters from the succeeding Requantization node. :param attrs: Quantization attributes of the original Substract or Multiply node. :param requant_attrs: Requantization attributes of the succeeding Requantization node. :return: Quantization attributes incorporating the requantization parameters, if the Requantization node can be fused, otherwise None. 
""" assert isinstance(attrs, (SubtractQuantAttrs, MultiplyQuantAttrs)) assert isinstance(attrs.requant, (FractionalZeroRequantization, TFLiteRequantization)) if not is_identity_requantization(attrs.requant) or attrs.requant.out_dtype != np.int32: return None return dataclasses.replace(attrs, requant=requant_attrs.requant) def _modify_conv_activ_attrs_for_fusion(attrs: _NodeAttributes, requant_attrs: RequantizeQuantAttrs) \ -> Optional[_NodeAttributes]: """ Modifies the quantization attributes for ConvAddActivation node to incorporate the requantization parameters from the succeeding Requantization node. :param attrs: Quantization attributes of the original ConvAddActivation node. :param requant_attrs: Requantization attributes of the succeeding Requantization node. :return: Quantization attributes incorporating the requantization parameters, if the Requantization node can be fused, otherwise None. """ assert isinstance(attrs, ConvQuantAttrs) assert isinstance(attrs.requant, (TFLiteRequantization, ArithFoldedRequantization)) if not is_identity_requantization(attrs.requant) or attrs.requant.out_dtype != np.int32: return None # Convert to ArithFoldedRequantization if possible match convert_to_arith_folded_requantization(requant_attrs.requant): case None: requant = requant_attrs.requant case x: requant = x # Fusion would make the activation run after requantization instead of before, # so it must be requantized. 
scalar_type = ScalarType.int16 if attrs.input_int16 else ScalarType.int8 new_activ_attrs = requantize_activation(attrs.activ_attrs, attrs.zero_point, requant, scalar_type) if attrs.activ_attrs is not None else None # Compute the new output zero point new_zero_point = int(requantize(np.array(attrs.zero_point), requant).item()) return dataclasses.replace(attrs, activ_attrs=new_activ_attrs, zero_point=new_zero_point, requant=requant) def _check_requantization_for_conv_fusion(requant: BaseRequantization) -> bool: """ Check whether the requantization type is supported with ConvAddActivationOp. :param requant: Requantization parameters of the Reqauntization node succeeding the ConvAddActivation node. :return: True if the requantization type is supported for Convolution, otherwise False. """ if can_convert_to_arith_folded_requantization(requant): return True # TFLiteRequantization is supported only for int8. if requant.out_dtype == np.int8: return is_arith_folded_requantization(requant) or is_tflite_requantization(requant) else: return is_arith_folded_requantization(requant) def _requantize_quantization(scale: float, zero_point: int, requant: BaseRequantization) -> Tuple[float, int]: """ Calculate parameters for the quantization node that is fused with succeeding requantization node. :param scale: Scale of the original quantization node. :param zero_point: Zero point of the original quantization node. :param requant: Requantization parameters of the succeeding requantization node. :return: Tuple containing scale, zero_point for the fused quantization node. 
""" # Based on requantization parametrs, calculate sc_corr and zp_corr parameters, so that quantization # parameters of the fused quantization + requantization node can be calculated as follows: # sc_fused = sc_orig * sc_corr # zp_fused = zp_orig * sc_corr + zp_corr if isinstance(requant, FractionalZeroRequantization): sc_corr = float(requant.sc_correction) * 2 ** -requant.shift zp_corr = int(requant.zp_correction * 2 ** -requant.shift) elif isinstance(requant, ArithFoldedRequantization): sc_corr = 2 ** -requant.shift zp_corr = 0 else: raise TypeError(f"Unsupported Requantization: {type(requant)}") return scale * sc_corr, int(zero_point * sc_corr) + zp_corr def _modify_quant_attrs_for_fusion(attrs: _NodeAttributes, requant_attrs: RequantizeQuantAttrs) \ -> Optional[_NodeAttributes]: """ Modifies the quantization attributes for Quantization node to incorporate the requantization parameters from the succeeding Requantization node. :param attrs: Attributes of the original Quantization node. :param requant_attrs: Requantization attributes of the succeeding Requantization node. :return: Attributes incorporating the requantization parameters, if the Requantization node can be fused, otherwise None. 
""" assert isinstance(attrs, QuantizationTransformAttrs) assert len(attrs.channel_params) == 1 sc, zp = attrs.channel_params[0] assert isinstance(requant_attrs.requant, (FractionalZeroRequantization, ArithFoldedRequantization)) new_sc, new_zp = _requantize_quantization(sc, zp, requant_attrs.requant) output_type = ScalarType.from_numpy(requant_attrs.requant.out_dtype) assert scalar_is_integral(output_type), f"Unsupported output type: {output_type}" return dataclasses.replace(attrs, channel_params=[(new_sc, new_zp)], num_bits=8 * scalar_byte_size(output_type), output_data_type=output_type) def _modify_udf_attrs_for_fusion(attrs: _NodeAttributes, requant_attrs: RequantizeQuantAttrs) \ -> Optional[_NodeAttributes]: """ Modifies the quantization attributes for UDF node to incorporate the requantization parameters from the succeeding Requantization node. :param attrs: Quantization attributes of the original UDF node. :param requant_attrs: Requantization attributes of the succeeding Requantization node. :return: Quantization attributes incorporating the requantization parameters, if the Requantization node can be fused, otherwise None. """ assert isinstance(attrs, UDFQuantAttrs) if not is_identity_requantization(attrs.requant) or attrs.requant.out_dtype != np.int32: return None return dataclasses.replace(attrs, requant=requant_attrs.requant) def _modify_const_attrs_for_fusion(attrs: _NodeAttributes, requant_attrs: RequantizeQuantAttrs) \ -> Optional[_NodeAttributes]: """ Modifies the quantization attributes for a constant node based on requantization The new constant value is computed and it replaces the original value. :param attrs: Quantization attributes of the original constant node. :param requant_attrs: Requantization attributes of the succeeding Requantization node. :return: Quantization attributes of the requantized constant node. 
""" assert isinstance(attrs, ConstantQuantAttrs) new_data = requantize(attrs.quant_data, requant_attrs.requant) return ConstantQuantAttrs(new_data)
# Signature of an attributes mutator: takes the attributes of the node preceding a Requantization
# node together with the Requantization node's quantized attributes, and returns the attributes of
# the fused node, or None when fusion is not possible.
AttributesMutatorCallable = Callable[[_NodeAttributes, RequantizeQuantAttrs], Optional[_NodeAttributes]]

# Signature of a requantization checker: takes the requantization parameters held by a
# Requantization node (the "requant" field of its RequantizeQuantAttrs, not the attributes object
# itself) and tells whether that kind of requantization is compatible with an operator.
RequantizationCheckCallable = Callable[[BaseRequantization], bool]
# Dictionary containing mapping of attributes modification function and requantization parameters checker function
# for operators supporting the fusion of the succeeding Requantization node. The requantization parameters checker
# function is used to determine if the requantization type is compatible with the operator. The attributes
# modification function is used to modify the parameters of the original node in order to incorporate the
# requantization from the succeeding Requantization node.
_NODE_FUSION_INTERFACE: Dict[Type[AwesomeOperation], Tuple[AttributesMutatorCallable, RequantizationCheckCallable]] = {
    AddActivationOp: (_modify_add_activ_attrs_for_fusion, is_fractional_zero_requantization),
    MaxPool2DOp: (_modify_pool_attrs_for_fusion, is_arith_folded_requantization),
    AvgPool2DOp: (_modify_pool_attrs_for_fusion, is_tflite_requantization),
    SubtractOp: (_modify_mul_sub_attrs_for_fusion, is_fractional_zero_requantization),
    MultiplyOp: (
        _modify_mul_sub_attrs_for_fusion,
        lambda x: is_fractional_zero_requantization(x) or is_tflite_requantization(x)
    ),
    ConvAddActivationOp: (_modify_conv_activ_attrs_for_fusion, _check_requantization_for_conv_fusion),
    QuantizationTransformOp: (
        _modify_quant_attrs_for_fusion,
        lambda x: is_arith_folded_requantization(x) or is_fractional_zero_requantization(x)
    ),
    UDFOp: (_modify_udf_attrs_for_fusion, is_arith_folded_requantization),
    ConstantOp: (_modify_const_attrs_for_fusion, lambda x: True)
}


def _try_fuse_requantization_node(node: AwesomeNode, successor: AwesomeNode) -> Optional[AwesomeNode]:
    """
    Fuse node with successor if they can be fused.

    Fusion handles the pattern of a node followed by a RequantizeOp. If some conditions
    are satisfied, the node can be modified to do the combined operation and the
    RequantizeOp node can be removed.

    This function assumes that the graph edges allow these two nodes to be fused:
    there is an edge from node to successor, and there are no edges from node
    to anything else. It does not assume anything else about node or successor.

    :param node: First node to attempt to fuse
    :param successor: Second node to attempt to fuse
    :return: New node that does the fused operation, with the same name as node.
        None if the two nodes cannot be fused.
    """
    if not node_is_sima_ir(node) or not node_is_requantization_node(successor):
        return None

    node_op = node.ir.operation
    # A quantization node keeps its parameters in its plain attributes; every other
    # operator keeps them in its quantized attributes.
    if isinstance(node_op, QuantizationTransformOp):
        node_attrs = node.ir.attrs
    else:
        node_attrs = node.ir.quant_attrs
    if node_attrs is None:
        # Only quantized nodes and quantization nodes can have a requantization as part of the node
        return None

    requant_quant_attrs = successor.ir.quant_attrs
    if requant_quant_attrs is None:
        # Only a quantized RequantizeOp can be fused
        return None
    assert isinstance(requant_quant_attrs, RequantizeQuantAttrs)

    # Successor is suitable for fusion. Try to do fusion for the node's operator.
    fusion_interface = _NODE_FUSION_INTERFACE.get(type(node_op))
    if fusion_interface is None:
        # This operator does not support fusion with a requantization.
        return None
    attrs_mutator_fn, requant_check_fn = fusion_interface

    # Check if requantization node is suitable for fusion with preceding node.
    if not requant_check_fn(requant_quant_attrs.requant):
        return None

    # Try modifying the node's quantization attributes.
    new_quant_attrs = attrs_mutator_fn(node_attrs, requant_quant_attrs)
    if new_quant_attrs is None:
        return None
    return _generate_fused_node(node, successor, new_quant_attrs)


def _create_requantization_fusion_list(net: AwesomeNet) \
        -> List[Tuple[AwesomeNode, AwesomeNode, AwesomeNode]]:
    """
    Attempt fusion on each node in the net. For each (node, successor_node) pair that the fusion
    is successful for, create a list entry consisting of a (node, successor_node, fused_node) tuple.
    Return the list of fusion entries for further processing. The net is not modified.

    The algorithm is as follows:
        1. Create a dictionary of {node_name: successor_node_names} entries defining
           successors of each node in a subgraph.
        2. Try fusing nodes with its successors. Only nodes that satisfy certain conditions can be fused:
            - Node must have exactly one successor node
            - Node must not be an output node of a given subgraph
            - Only SiMaIR nodes can be fused with RequantizeOp nodes
            - For each operator in a SiMaIR node, there are additional conditions that need
              to be fulfilled (defined in _NODE_FUSION_INTERFACE)

    :param net: AwesomeNet (subgraph) for which the fusion is attempted.
    :return: The list of (node, successor_node, fused_node) tuples defining the result of
        the fusion for a node, successor_node pair.
    """
    successors: Dict[NodeName, List[NodeName]] = _collect_successors(net)

    # Attempt fusion on each node in this subgraph.
    # First, scan and find all nodes to fuse.
    # Items in fusion_list are (a, b, r) where r will replace a and b.
    fusion_list: List[Tuple[AwesomeNode, AwesomeNode, AwesomeNode]] = []
    for node_name, node in net.nodes.items():
        # This node's output must be used only in the
        # input of a requantization node
        if node_name == net.output_node_name or len(successors[node_name]) != 1:
            continue

        successor_name, = successors[node_name]
        successor = net.nodes[successor_name]
        fuse_result = _try_fuse_requantization_node(node, successor)
        if fuse_result is None:
            continue

        # These nodes will be fused
        sima_logger.sima_log_info(f"Fusing node {node.name} with a requantization node")
        fusion_list.append((node, successor, fuse_result))

    return fusion_list


def _mutate_net_from_requantization_fusion_list(
        net: AwesomeNet,
        fusion_list: List[Tuple[AwesomeNode, AwesomeNode, AwesomeNode]]
):
    """
    Mutates the graph according to the results of the node fusion that are defined by
    the entries given in the fusion list.

    :param net: Network that is being subject of the fusion. It is mutated in place.
    :param fusion_list: The list of entries defining the result of the fusion of the given nodes.
        Each entry is a tuple of (node, successor_node, fused_node). In the resulting network,
        each pair of node, successor_node is replaced by the fused_node.
    :return: None. The result is the mutated net.
    """
    # Replace nodes that were fused.
    # Verify that no node is involved in more than one fusion transformation.
    replaced_set = set()
    for (a, b, r) in fusion_list:
        assert a.name not in replaced_set
        assert b.name not in replaced_set
        del net.nodes[a.name]
        del net.nodes[b.name]
        replaced_set.add(a.name)
        replaced_set.add(b.name)
        net.nodes[r.name] = r

    # Replace references to old successor node by references to fuse_result
    rn = Renaming({successor.name: fuse_result.name for (_, successor, fuse_result) in fusion_list})
    rename_mut_awesomenet(rn, net)

    # Recalculate structure of the network
    net.topological_sort()


def _do_requantization_fusion_in_net(net: AwesomeNet):
    """
    Do the fusion transformation on all nodes in net. Subgraphs are not handled in this function.

    :param net: Network to transform. It is mutated in place.
    """
    fusion_list = _create_requantization_fusion_list(net)
    _mutate_net_from_requantization_fusion_list(net, fusion_list)


def _create_constants_requantization_fusion_list(net: AwesomeNet) \
        -> List[Tuple[AwesomeNode, AwesomeNode]]:
    """
    Attempt fusion on each constant node in the net. For each (constant_node, successor_node) pair that the fusion
    is successful for, create a list entry consisting of a (successor_node, fused_node) tuple.
    Return the list of fusion entries for further processing. The net is not modified.

    The algorithm is as follows:
        1. Create a dictionary of {constant_node_name: requantization_node_names} entries defining
           requantization successors of each constant node in a subgraph.
        2. Try fusing nodes with its successors. Only constant nodes and its successors that satisfy
           certain conditions given in the _NODE_FUSION_INTERFACE[ConstantOp] can be fused.

    :param net: AwesomeNet (subgraph) for which the fusion is attempted.
    :return: The list of (successor_node, fused_node) tuples defining the result of
        the fusion for a constant_node, successor_node pair.
    """
    successors = _collect_successors(net, node_is_constant, node_is_requantization_node)

    # Attempt fusion on each constant node in this subgraph.
    # First, scan and find all constant nodes to fuse.
    # Items in fusion_list are (a, r) where r will replace a.
    fusion_list: List[Tuple[AwesomeNode, AwesomeNode]] = []
    for node_name, successor_names in successors.items():
        node = net.nodes[node_name]
        for idx, successor_name in enumerate(successor_names):
            successor = net.nodes[successor_name]
            fuse_result = _try_fuse_requantization_node(node, successor)
            if fuse_result is None:
                continue

            # Rename the generated node as the constant node may be used in multiple
            # successors.
            fuse_result.name = f"{fuse_result.name}_requant_{idx}"

            # These nodes will be fused
            sima_logger.sima_log_info(
                f"Fusing node {node_name} with a requantization node {successor_name}"
            )
            fusion_list.append((successor, fuse_result))

    return fusion_list


def _mutate_net_from_constants_requantization_fusion_list(
        net: AwesomeNet,
        fusion_list: List[Tuple[AwesomeNode, AwesomeNode]]
):
    """
    Mutates the graph according to the results of the node fusion that are defined by
    the entries given in the fusion list.

    :param net: Network that is being subject of the fusion. It is mutated in place.
    :param fusion_list: The list of entries defining the result of the fusion of the given constant nodes.
        Each entry is a tuple of (requantization_node, fused_node). In the resulting network,
        each requantization_node is replaced by the fused_node.
    :return: None. The result is the mutated net.
    """
    # Replace nodes that were fused.
    # Verify that no node is involved in more than one fusion transformation.
    replaced_set = set()
    for (a, r) in fusion_list:
        assert a.name not in replaced_set
        del net.nodes[a.name]
        replaced_set.add(a.name)
        net.nodes[r.name] = r

    # Replace references to old successor nodes by references to fuse_result.
    rn = Renaming({successor.name: fuse_result.name for (successor, fuse_result) in fusion_list})
    rename_mut_awesomenet(rn, net)

    # Remove constant nodes that have no successors in the model.
    # Constant nodes have no input nodes, so no renaming is needed here.
    # The net's output node is never removed: it has no successors by construction,
    # and deleting it would leave the net without its output.
    dangling_constants = [
        constant_name
        for constant_name, users in _collect_successors(net, node_is_constant).items()
        if not users and constant_name != net.output_node_name
    ]
    for constant_name in dangling_constants:
        del net.nodes[constant_name]

    # Recalculate structure of the network
    net.topological_sort()


def _do_constant_requantization_fusion_in_net(net: AwesomeNet):
    """
    Do the fusion transformation on all constant nodes in net. Subgraphs are not handled in this function.

    :param net: Network to transform. It is mutated in place.
    """
    fusion_list = _create_constants_requantization_fusion_list(net)
    _mutate_net_from_constants_requantization_fusion_list(net, fusion_list)


def _do_requantization_fusion(net: AwesomeNet):
    """
    Do the fusion transformation on all nodes in net, including subgraphs.

    Constants are handled first, then the remaining nodes, then each subgraph recursively.

    :param net: Network to transform. It is mutated in place.
    """
    _do_constant_requantization_fusion_in_net(net)
    _do_requantization_fusion_in_net(net)

    # Do fusion in subgraphs
    for node in net.nodes.values():
        if isinstance(node.ir, AwesomeNet):
            _do_requantization_fusion(node.ir)
class FuseRequantizations(BaseTransform):
    """
    Compiler pass that fuses RequantizeOp nodes, such as the ones inserted during
    quantization, into the node that precedes them wherever that is possible.
    """

    def __call__(self, net: AwesomeNet) -> None:
        """
        Run requantization fusion on the given net and on its subgraphs.

        :param net: Network to transform. It is modified in place.
        """
        _do_requantization_fusion(net)