# Source code for afe.ir.transform.requantization_hoisting.hoisting_transform

#########################################################
# Copyright (C) 2024 SiMa Technologies, Inc.
#
# This material is SiMa proprietary and confidential.
#
# This material may not be copied or distributed without
# the express prior written permission of SiMa.
#
# All rights reserved.
#########################################################
# Code owner: Christopher Rodrigues
#########################################################
"""
This module defines the requantization hoisting
transformation.  It moves Requantize nodes from their
original locations (usually at the input of a node) to
immediately after the node that produces its input data.
"""
from typing import Dict, List, Optional, Tuple
import dataclasses

import numpy as np
from ml_kernels.requantization import BaseRequantization
from sima_utils.logging.sima_logger import sima_log_info

from afe.backends import Backend
from afe.ir.build_node import create_requantization_node, create_tuple_get_item_nodes, create_tuple_output_node
from afe.ir.defines import DataValue, TensorValue, TupleValue, NodeName, get_expected_tensor_value, zip_data_value, \
    map_data_value, DataIndex, TensorIndex, TupleIndex, index_data_value, data_value_elements, get_expected_tuple_values
from afe.ir.net import AwesomeNet
from afe.ir.node import AwesomeNode, node_is_placeholder
from afe.ir.quantization_utils import requantize_quantization
from afe.ir.sima_ir import SiMaIR, SiMaIRMetadata
import afe.ir.operations as afe_op
import afe.ir.attributes as afe_attr
from afe.ir.tensor_type import TensorType, ScalarType
from afe.ir.transform.requantization_hoisting.defines import Need, NeedMapping, need_mapping_get, need_mapping_find, \
    need_mapping_singleton, need_mapping_insert, DataNeeds, need_mapping_empty


class RqRenaming:
    """
    A renaming on requantizations.  It maps (tensor, need) pairs in the old
    net to tensors in the new net.  Where a model uses tuples, it has
    mappings for the tensors in the tuples.

    For a value named N in the old net, self.find(N) is a DataValue
    identifying the tensor values corresponding to N in the new net.  As a
    tensor may have been requantized several ways, the corresponding tensor
    values are stored as a NeedMapping[NodeName].
    """
    # The Dict and NeedMapping objects are mutable
    _renaming: Dict[NodeName, DataValue[NeedMapping[NodeName]]]

    def __init__(self) -> None:
        self._renaming = {}

    def find(self, n: NodeName) -> DataValue[NeedMapping[NodeName]]:
        """
        Get all renaming information for the output of one node.
        It is an error if the node is not found.

        :param n: Name of a node in the old net
        :return: Renaming information recorded for n
        :raises KeyError: If nothing was recorded for n
        """
        return self._renaming[n]

    def find_tensor(self, n: NodeName, need: Need) -> NodeName:
        """
        Find the new tensor representing the given node's output and need.
        The information recorded for n must be a TensorValue.

        :param n: Name of a node in the old net
        :param need: Requantization applied to the output of n
        :return: Result of need_mapping_find on the mapping recorded for n
        """
        data_value = self.find(n)
        assert isinstance(data_value, TensorValue), \
            f"Renaming for node {n} is not a single tensor"
        return need_mapping_find(data_value.value, need)

    def get(self, n: NodeName) -> Optional[DataValue[NeedMapping[NodeName]]]:
        """
        Get all renaming information for the output of one node,
        or None if there is no information.

        :param n: Name of a node in the old net
        :return: Renaming information recorded for n, or None
        """
        return self._renaming.get(n)

    def get_tensor(self, n: NodeName, need: Need) -> Optional[NodeName]:
        """
        Get the new tensor representing the given node's output and need,
        or None if there is no such tensor.  If information is recorded
        for n, it must be a TensorValue.

        :param n: Name of a node in the old net
        :param need: Requantization applied to the output of n
        :return: None if nothing is recorded for n; otherwise the result of
            need_mapping_get on the mapping recorded for n
        """
        data_value = self.get(n)
        if data_value is None:
            return None
        assert isinstance(data_value, TensorValue), \
            f"Renaming for node {n} is not a single tensor"
        return need_mapping_get(data_value.value, need)

    def assign(self, n: NodeName, v: DataValue[NeedMapping[NodeName]]) -> None:
        """
        Assign the renaming information for the output of one node.
        A node may be assigned only once.

        :param n: Name of a node in the old net
        :param v: Renaming information to record for n
        """
        assert n not in self._renaming, f"Value is already assigned for node {n}"
        self._renaming[n] = v

    def assign_tensor(self, n: NodeName, transformed_n: NodeName) -> None:
        """
        Assign the renaming information for the output of one node, where
        the output is a single tensor and it is not requantized.

        :param n: Name of a node in the old net
        :param transformed_n: Node in the new net that is recorded for n
            under the need None
        """
        self.assign(n, TensorValue(need_mapping_singleton(None, transformed_n)))

    def update_tensor(self, n: NodeName, need: Need, transformed_n: NodeName) -> None:
        """
        Record that transformed_n represents the single tensor output of `n`
        requantized by `need`.  Node n must output a single tensor.
        Other renamings for node n are not affected.

        :param n: Node in old model that will be associated to a node in the new model
        :param need: Requantization that is performed on the output of n
        :param transformed_n: Node in new model that will be associated to n
        """
        entry = self._renaming.get(n)
        if entry is None:
            # First renaming seen for n: start from an empty mapping, built
            # with the same constructor the rest of this module uses.
            entry = TensorValue(need_mapping_empty())
            self._renaming[n] = entry
        assert isinstance(entry, TensorValue), \
            f"Renaming for node {n} is not a single tensor"
        need_mapping_insert(entry.value, need, transformed_n)
class _GlobalData: """ Data that is global over the scope of the transformation. """ # Needs that were computed by need analysis. # This value is read-only. _node_needs: Dict[NodeName, DataNeeds] # Needs that were computed on inputs of a subgraph by need analysis. # subgraph_inputs[node_name][parameter_index] is a list of needs # for that parameter index. The parameter list is in the same order as # the net's input_node_names field. # This value is read-only. _subgraph_inputs: Dict[NodeName, List[DataNeeds]] # Unique ID value for creating new node names _requant_node_id: int def __init__(self, node_needs: Dict[NodeName, DataNeeds], subgraph_inputs: Dict[NodeName, List[DataNeeds]]): self._node_needs = node_needs self._subgraph_inputs = subgraph_inputs self._requant_node_id = 0 def get_node_needs(self, name: NodeName) -> DataNeeds: """ Look up the needs for one node's output. The returned value has information about all requantizations that were applied to all of the node's output tensors in the original model. """ return self._node_needs[name] def get_subgraph_inputs(self, name: NodeName) -> List[DataNeeds]: """ Look up the needs for one subgraph's inputs. The returned value has information about all requantizations that were applied to all of the subgraph's inputs in the original model. The list has the same order as the subgraph's input node list. """ return self._subgraph_inputs[name] def new_requant_node_id(self) -> int: """ Get a new, unique ID to use for naming Requantize nodes. """ n = self._requant_node_id self._requant_node_id = n + 1 return n @dataclasses.dataclass class _NetConversionState: """ State associated with converting one AwesomeNet. """ # Renaming to apply to tensors in the AwesomeNet. # This is used to look up the new tensor that corresponds to a # given value from the old model. renaming: RqRenaming # Nodes to be included in the AwesomeNet after conversion. 
# This field is initialized to an empty list and is updated as # nodes are created for the new AwesomeNet. node_list: List[AwesomeNode] def _add_node(s: _NetConversionState, node: AwesomeNode) -> None: """ Insert one node into s's node list. """ s.node_list.append(node) def _add_nodes(s: _NetConversionState, nodes: List[AwesomeNode]) -> None: """ Insert a list of nodes into s's node list. """ for node in nodes: s.node_list.append(node) def _make_tuple_node(s: _NetConversionState, name_prefix: str, inputs: List[NodeName]) -> AwesomeNode: """ Create a tuple node having the given inputs. The node is added to s's node list and also returned. :param s: Net conversion state :param name_prefix: String to use in constructing the tuple node's name. Text will be appended to make the name. This string must be unique among all calls to _make_tuple_node for the model graph. :param inputs: Names of tuple's input nodes. These nodes must have been inserted into s's node list. :return: New tuple node """ node_name_map = _make_node_dict(s) nodes = [node_name_map[n] for n in inputs] tuple_node = create_tuple_output_node(nodes, name_prefix) _add_node(s, tuple_node) return tuple_node def _make_tuple_get_item_nodes(s: _NetConversionState, name_prefix: str, types: List[TensorType], input_node: NodeName) -> List[AwesomeNode]: """ Create nodes to get each item of the given tuple node. The nodes are inserted into s's node list and also returned. :param s: Net conversion state :param name_prefix: Prefix used to make new node names. The prefix must be unique across all calls to _make_tuple_get_item_nodes for the model graph. :param types: Types of the items in the tuple. :param input_node: Tuple node to get items from. :return: List of new nodes, in the same order as tuple elements. 
""" tuple_get_item_nodes = create_tuple_get_item_nodes(input_node, types, name_prefix) _add_nodes(s, tuple_get_item_nodes) return tuple_get_item_nodes def _get_node_output_type(s: _NetConversionState, node_name: NodeName) -> DataValue[TensorType]: """ Get the output type of the node having the given name. The node must have been inserted into s's node list. """ for node in s.node_list: if node.name == node_name: return node.get_type().output raise KeyError("No node with name " + node_name) def _make_node_dict(s: _NetConversionState) -> Dict[NodeName, AwesomeNode]: """ Prepare to put all of s's nodes into an AwesomeNet. Convert s's node list to a dict that can be passed to AwesomeNet's constructor. """ return {n.name: n for n in s.node_list} def _convert_placeholder(old_input: AwesomeNode, need: Need, tag: str) -> AwesomeNode: """ Convert a placeholder node. Return a new node that represents the tensor equivalent to `old_input`'s tensor requantized by `need`. :param old_input: Original node. It must be a placeholder. :param need: Requantization that is applied to the original node. The new placeholder represents the result of this requantization. :param tag: Tag to append to the node name. Appending the tag must produce a globally unique name in the model. :return: New node. """ assert node_is_placeholder(old_input) old_attrs = old_input.ir.quant_attrs if old_input.ir.quant_attrs is not None else old_input.ir.attrs assert isinstance(old_attrs, (afe_attr.PlaceholderAttrs, afe_attr.PlaceholderQuantAttrs)) old_quantization = old_attrs.quantization if isinstance(old_attrs, afe_attr.PlaceholderQuantAttrs) else None if need is None: # The new placeholder has same information as the old placeholder new_scalar_type = old_attrs.type.scalar new_quantization = old_quantization else: # The new placeholder's data is the output of a requantization node. # When need != None, it implies that the old quantization is known, # so we can use this info to calculate the new quantization. 
assert old_quantization is not None new_scalar_type = ScalarType.from_numpy(need.out_dtype) new_quantization = requantize_quantization(old_quantization, need) new_type = TensorType(new_scalar_type, old_attrs.type.shape) new_ir = SiMaIR(afe_op.PlaceholderOp(), None, afe_attr.AwesomeCalibAttrs(), afe_attr.PlaceholderQuantAttrs(new_type, new_quantization), old_input.ir.quant_config, Backend.NONE, SiMaIRMetadata(old_input.name)) new_name = NodeName(old_input.name + tag) # The old value of _layer_stats remains valid if no requantization was applied layer_stats = old_input.layer_stats if need is None else None return AwesomeNode(new_name, afe_op.PlaceholderOp.input_list, [new_name], new_ir, _status=old_input.status, _layer_stats=layer_stats) def _log_info_for_convert_subgraph(old_node: AwesomeNode, new_node: AwesomeNode): """ Print information about conversion results for debugging. """ assert isinstance(old_node.ir, AwesomeNet) assert isinstance(new_node.ir, AwesomeNet) ph_strings = [] for ph_name in old_node.ir.input_node_names: ph = old_node.ir.nodes[ph_name] ph_strings.append(f"{ph.name}: {ph.get_type().output}") sima_log_info("Old placeholders: " + ", ".join(ph_strings)) sima_log_info("Old result: " + str(old_node.ir.nodes[old_node.ir.output_node_name].get_type().output)) ph_strings = [] for ph_name in new_node.ir.input_node_names: ph = new_node.ir.nodes[ph_name] ph_strings.append(f"{ph.name}: {ph.get_type().output}") sima_log_info("New placeholders: " + ", ".join(ph_strings)) sima_log_info("New result: " + str(new_node.ir.nodes[new_node.ir.output_node_name].get_type().output)) def _convert_subgraph(global_data: _GlobalData, state: _NetConversionState, old_node: AwesomeNode) -> None: """ Convert a subgraph. The result of conversion is saved in state. :param old_node: Node to convert. It must contain an AwesomeNet. :param state: Conversion data for the network that contains old_node. It will be updated with the result of conversion. 
:param global_data: Global data for this transformation. """ net_node_name = old_node.name old_argument_list = old_node.input_node_names net = old_node.ir assert isinstance(net, AwesomeNet) # Create placeholders for the subgraph's inputs. # Rename placeholders. node_list: List[AwesomeNode] = [] placeholder_name_list: List[NodeName] = [] new_argument_list: List[NodeName] = [] subgraph_renaming = RqRenaming() subgraph_inputs = global_data.get_subgraph_inputs(net_node_name) assert len(net.input_node_names) == len(old_argument_list) == len(subgraph_inputs) for old_input_name, old_argument, input_needs in zip(net.input_node_names, old_argument_list, subgraph_inputs): assert isinstance(input_needs, TensorValue) for j, input_need in enumerate(input_needs.value): tag = "_" + str(j) old_input = net.nodes[old_input_name] new_placeholder = _convert_placeholder(old_input, input_need, tag) new_argument = state.renaming.find_tensor(old_argument, input_need) node_list.append(new_placeholder) placeholder_name_list.append(new_placeholder.name) new_argument_list.append(new_argument) subgraph_renaming.update_tensor(old_input_name, input_need, new_placeholder.name) # Convert nodes inside the subgraph local_state = _NetConversionState(renaming=subgraph_renaming, node_list=node_list) _convert_subgraph_nodes(global_data, local_state, net) # Decide what will be the output of the converted subgraph and record it in subgraph_outputs. # A list of 1 or more tensors is used to describe one tensor (if its length is 1) or # a tuple of 2 or more tensors (otherwise). # When subgraph_outputs[i] = (j, need, name), this means that the new output at # position 'i' has the same value as taking the old output at position 'j' and # requantizing it by 'need', and also that this output receives its value from # tensor 'name' in the subgraph. 
subgraph_outputs: List[Tuple[DataIndex, Need, NodeName]] = [] output_values: DataValue[NeedMapping[NodeName]] = local_state.renaming.find(net.output_node_name) match zip_data_value(lambda x, y: (x, y), output_values, global_data.get_node_needs(net_node_name)): case TensorValue((output_tensor_value, output_tensor_needs)): for need in output_tensor_needs: subgraph_outputs.append((TensorIndex(), need, need_mapping_find(output_tensor_value, need))) case TupleValue(elements): for i, element_value in enumerate(elements): assert isinstance(element_value, TensorValue) output_tensor_value, output_tensor_needs = element_value.value for need in output_tensor_needs: subgraph_outputs.append((TupleIndex(i, TensorIndex()), need, need_mapping_find(output_tensor_value, need))) assert len(subgraph_outputs) > 0, "Subgraph has no outputs after transformation" if len(subgraph_outputs) > 1: # Handle tuple and tuple_get_item nodes for multiple outputs. # Create a tuple node for the output, create tuple_get_item nodes in the parent graph, # and insert the tuple_get_item nodes into the renaming for the parent graph. output_tuple_item_names = [name for _, _, name in subgraph_outputs] output_types = [get_expected_tensor_value(_get_node_output_type(local_state, name)) for name in output_tuple_item_names] output_node = _make_tuple_node(local_state, net_node_name, output_tuple_item_names) tuple_get_item_nodes = _make_tuple_get_item_nodes(state, net_node_name, output_types, net_node_name) # Create a key-value list mapping old subgraph's outputs to new subgraph's outputs output_mapping_entries: List[Tuple[DataIndex, Need, NodeName]] = [] for (data_index, need, _), output_tuple_node in zip(subgraph_outputs, tuple_get_item_nodes): output_mapping_entries.append((data_index, need, output_tuple_node.name)) output_node_name = output_node.name else: # The subgraph has a single output. Add this single output to the renaming for the parent graph. 
data_index, need, output_node_name = subgraph_outputs[0] output_mapping_entries = [(data_index, need, net_node_name)] # Insert the output renaming info into parent net's state parent_output_value_renaming: DataValue[NeedMapping[NodeName]] = \ map_data_value(lambda _: need_mapping_empty(), output_values) for (data_index, need, name) in output_mapping_entries: need_mapping_insert(index_data_value(parent_output_value_renaming, data_index), need, name) state.renaming.assign(net_node_name, parent_output_value_renaming) # Add the new subgraph to the parent net. # Fields _execution_order, _prune_dict, _float_node_list, _output_labels, and _fp_input_range # are discarded because the nodes they refer to may no longer exist. new_net = AwesomeNet(name=net.name, nodes=_make_node_dict(local_state), input_node_names=placeholder_name_list, output_node_name=output_node_name, _status=net._status, _is_subgraph=True, _backend=net._backend, _target=net._target) new_node = AwesomeNode(name=net_node_name, input_names=placeholder_name_list, input_node_names=new_argument_list, ir=new_net, _status=old_node._status) # _print_info_for_convert_subgraph(old_node, new_node) _add_node(state, new_node) def _convert_subgraph_nodes(global_data: _GlobalData, state: _NetConversionState, net: AwesomeNet) -> None: """ Convert all nodes in a net except for placeholders. The state must be initialized to represent the state at the beginning of the net, including information about placeholders. """ net.topological_sort() for node_name in net.execution_order: node = net.nodes[node_name] if node_is_placeholder(node): continue _convert_node(global_data, state, node) def _convert_sima_ir_node(global_data: _GlobalData, state: _NetConversionState, node: AwesomeNode) -> None: """ Convert a SiMaIR node. This function is used when no special-case conversion function is applicable. The node's output type must be a tensor. 
Conversion inserts requantization nodes to transform the node's output, and it makes the node take its inputs from previously created nodes. The results of conversion are saved in state. """ assert isinstance(node.ir, SiMaIR) # Find the node's new inputs. Assume the inputs are tensors. input_node_names = [state.renaming.find_tensor(n, None) for n in node.input_node_names] # Create the new node. Only the inputs have changed. new_node = dataclasses.replace(node, input_node_names=input_node_names) _add_node(state, new_node) # Create nodes to transform the new node's output value if isinstance(node.ir.calib_attrs.quant, TupleValue): # The node's output is a tuple. Extract and record the individual tensors. # Since the quantizer does not quantize nodes that output a tuple, no requantization nodes are created. # Verify that there is no requantization associated with this node's output. assert all(need is None for needs in data_value_elements(global_data.get_node_needs(node.name)) for need in needs) # Reconstruct TupleGetItem nodes node_output_types = get_expected_tuple_values(node.get_type().output) tuple_get_item_nodes = create_tuple_get_item_nodes(node.name, node_output_types, node.name + "/requant") # Record the new nodes _add_nodes(state, tuple_get_item_nodes) # Associate new output to original output output_need_mapping = TupleValue([TensorValue(need_mapping_singleton(None, tgi_node.name)) for tgi_node in tuple_get_item_nodes]) state.renaming.assign(node.name, output_need_mapping) else: # Create requantization node and update renaming for each need output_qrtt = get_expected_tensor_value(node.ir.calib_attrs.quant) output_quant = output_qrtt.quant node_type = get_expected_tensor_value(node.get_type().output) state.renaming.update_tensor(node.name, None, node.name) # Add a mapping for directly using the node's output for need in get_expected_tensor_value(global_data.get_node_needs(node.name)): if need is None: continue # If the data was requantized, the quantization 
must be known. # Use it to calculate the quantization we will have at the output of requantization. assert isinstance(need, BaseRequantization) assert output_quant is not None, "Attempted to requantize a tensor having unknown quantization" id_num = global_data.new_requant_node_id() requantized_quant = requantize_quantization(output_quant, need) requant_method = output_qrtt.requant_method assert requant_method is not None rq_node = create_requantization_node(node.name, id_num, node_type, output_quant, requantized_quant, requant_method, need) _add_node(state, rq_node) state.renaming.update_tensor(node.name, need, rq_node.name) def _empty_need_mapping(output: DataValue[TensorType]) -> DataValue[NeedMapping[NodeName]]: """ Create need information for a value of the given type that is not used. """ return map_data_value(lambda _: list(), output) def _rq_renaming_find_with_default(r: RqRenaming, n: NodeName, default_value: DataValue[NeedMapping[NodeName]]) \ -> DataValue[NeedMapping[NodeName]]: x = r.get(n) if x is not None: return x else: return default_value def _convert_node(global_data: _GlobalData, state: _NetConversionState, node: AwesomeNode) -> None: """ Convert a node to use new requantizations. If it is a type of node that will be reconstructed based on analysis results, then the conversion state is updated to record the mapping from old node to new nodes, and the node is discarded. Otherwise, the node's inputs are renamed and requantizations are inserted for the node's outputs. :param global_data: Global information used by the transformation :param state: Transformation state for the current AwesomeNet :param node: Node to convert """ if isinstance(node.ir, SiMaIR): attrs = node.ir.quant_attrs if node.ir.quant_attrs is not None else node.ir.attrs match node.ir.operation: case afe_op.PlaceholderOp(): raise TypeError("Placeholder should not be handled here") case afe_op.RequantizeOp(): # The algorithm created a replacement for this node if it was needed. 
# Look up this node's replacement and update the renaming. assert isinstance(attrs, afe_attr.RequantizeQuantAttrs) input_node_name = state.renaming.get_tensor(node.input_node_names[0], attrs.requant) if input_node_name is not None: state.renaming.assign_tensor(node.name, input_node_name) case afe_op.TupleOp(): # The algorithm will recreate tuple nodes if needed. # Update the renaming for this node's output. assert isinstance(attrs, afe_attr.TupleAttrs) empty_need = _empty_need_mapping(node.get_type().output) assert isinstance(empty_need, TupleValue) input_needs_list: List[DataValue[NeedMapping[NodeName]]] input_needs_list = [_rq_renaming_find_with_default(state.renaming, n, default_value) for n, default_value in zip(node.input_node_names, empty_need.elements)] state.renaming.assign(node.name, TupleValue(input_needs_list)) case afe_op.TupleGetItemOp(): # The algorithm will recreate tuple get item nodes if needed. # Update the renaming for this node's output. assert isinstance(attrs, afe_attr.TupleGetItemAttrs) empty_need = _empty_need_mapping(node.get_type().output) input_needs = _rq_renaming_find_with_default(state.renaming, node.input_node_names[0], empty_need) assert isinstance(input_needs, TupleValue) node_needs = input_needs.elements[attrs.index] state.renaming.assign(node.name, node_needs) case _: _convert_sima_ir_node(global_data, state, node) elif isinstance(node.ir, AwesomeNet): _convert_subgraph(global_data, state, node) else: raise TypeError("Unexpected node type")
[docs] def move_requantization_in_model_graph(net: AwesomeNet, node_needs: Dict[NodeName, DataNeeds], subgraph_inputs: Dict[NodeName, List[DataNeeds]]) -> AwesomeNet: """ Place all Requantize nodes as early as possible in the model graph. All requantize nodes are removed, and the requantize nodes which are needed are reconstructed in the same location as the node that produced their input. :param net: Model to transform :param node_needs: Uses of the outputs of all nodes, as computed by analyze_needs :param subgraph_inputs: Uses of the inputs of all subgraphs, as computed by analyze_needs :return: Transformed model """ global_data = _GlobalData(node_needs, subgraph_inputs) renaming = RqRenaming() new_nodes = [] # Create new placeholders. These are identical to the old placeholders. for placeholder_name in net.input_node_names: new_nodes.append(_convert_placeholder(net.nodes[placeholder_name], None, "")) renaming.assign_tensor(placeholder_name, placeholder_name) state = _NetConversionState(renaming, new_nodes) _convert_subgraph_nodes(global_data, state, net) # Find the output node. Make a tuple node if needed. match state.renaming.find(net.output_node_name): case TensorValue(need_mapping): output_node_name = need_mapping_find(need_mapping, None) case TupleValue(need_mappings): # Construct a tuple node for the output output_element_names = [need_mapping_find(get_expected_tensor_value(m), None) for m in need_mappings] output_node = _make_tuple_node(state, net.name, output_element_names) output_node_name = output_node.name case _: raise TypeError("Unexpected DataValue type") # _execution_order, _prune_dict, _float_node_list are discarded as node names may have changed. # Input node names have not changed. 
return AwesomeNet(name=net.name, nodes=_make_node_dict(state), input_node_names=net.input_node_names, output_node_name=output_node_name, _status=net._status, _is_subgraph=False, _backend=net._backend, _target=net._target, _output_labels=net._output_labels, _model_path=net._model_path, _fp_input_range=net._fp_input_range)