Source code for afe.core.mixed_precision.annotation

#########################################################
# Copyright (C) 2024 SiMa Technologies, Inc.
#
# This material is SiMa proprietary and confidential.
#
# This material may not be copied or distributed without
# the express prior written permission of SiMa.
#
# All rights reserved.
#########################################################
# Code owner: Shreyas Kera
#########################################################
from typing import List, Tuple, Optional, Dict, Union

import onnx
from onnx.helper import make_node, make_opsetid


def remove_weight_quant(model: onnx.ModelProto) -> onnx.ModelProto:
    """
    Removes weight quantizers in an MCT-quantized model.

    :param model: Onnx model with weight holders.
    :return: model. Onnx model without weight holders.
    """
    # Iterate over a copy since nodes are removed from the graph below
    for node in list(model.graph.node):
        if node.op_type == 'WeightsSymmetricQuantizer':
            next_node, index = find_node_by_input(model, node.output[0])
            # Connect next node's input to the weight holder's weight
            next_node.input[index] = node.input[0]
            constant_name = node.input[1]
            # Remove the weight holder
            model.graph.node.remove(node)
            constant_node = find_node_by_output(model, constant_name)
            # Remove the constant associated with the weight holder
            model.graph.node.remove(constant_node)
    return model
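
# Illustrative sketch (not from the original source; op and tensor names are hypothetical):
# a pattern of the form
#     weight, Constant -> WeightsSymmetricQuantizer -> Conv
# is rewritten so that Conv reads the weight tensor directly,
#     weight -> Conv
# and both the quantizer node and its companion Constant node are dropped from the graph.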


def update_precision(model: onnx.ModelProto, promotion_list: List[str]) -> onnx.ModelProto:
    """
    Updates the precision of the layers to promote in an annotated ONNX model.

    :param model: Onnx model with activation holders.
    :param promotion_list: List of layer names to promote to 16 bit.
    :return: model. Onnx model with activation holders with appropriate bit settings.
    """
    for node in model.graph.node:
        if node.name in promotion_list:
            # Overwrite the holder's num_bits attribute (the last attribute) with 16
            attr = onnx.helper.make_attribute("num_bits", 16)
            node.attribute[-1].CopyFrom(attr)
    return model


# Returned to indicate that a tensor's precision should be ignored
class IgnoreTensorPrecision:
    pass


def lookup_tensor_precision(model: onnx.ModelProto, precs: Dict[str, Optional[int]], tensor: str) \
        -> Union[None, IgnoreTensorPrecision, int]:
    """
    Find the precision that mixed precision search has chosen for a tensor
    according to attributes of Activation Holders.

    :param model: Onnx model with activation holders.
    :param precs: Precisions associated with nodes calculated so far. It holds
        at least the nodes that topologically precede the tensor.
    :param tensor: Tensor to look up.
    :return: Precision of the tensor.
    """
    if tensor == "":
        return IgnoreTensorPrecision()
    if any(model_input.name == tensor for model_input in model.graph.input):
        # Inputs have int16 precision
        return 16
    elif is_constant_tensor(model, tensor):
        # Constants are ignored
        return IgnoreTensorPrecision()
    else:
        # All other tensors must be an output of a node
        node_check = find_node_by_output(model, tensor)
        return precs[node_check.name]


def is_constant_tensor(model: onnx.ModelProto, tensor: str) -> bool:
    """
    Return true if the tensor is an initializer or the output of a constant operator.
    """
    if any(initializer.name == tensor for initializer in model.graph.initializer):
        return True
    for node in model.graph.node:
        if node.output[0] == tensor:
            # The tensor is the output of this node. Decide whether it is a constant operator.
            return node.op_type == 'Constant'
    return False


def get_redundant_holders(model: onnx.ModelProto) -> List[str]:
    """
    Find redundant Activation Holders from topologically sorted onnx nodes;
    only those Activation Holders that indicate a precision switch should be kept.

    :param model: Onnx model with activation holders.
    :return: remove_list. List of redundant activation holders to remove.
    """
    # Maintain a dictionary of nodes and corresponding precision
    precs = {}
    remove_list = []
    for node in model.graph.node:
        if node.op_type == 'ActivationUniformQuantizer':
            node_input = node.input[0]
            precs[node.name] = int(node.attribute[-1].i)
            if any(model_input.name == node_input for model_input in model.graph.input) \
                    or is_constant_tensor(model, node_input):
                # Activation holder is on a constant or an input of the model. Remove it.
                remove_list.append(node.name)
            else:
                input_prec = lookup_tensor_precision(model, precs, node_input)
                if isinstance(input_prec, int) and input_prec == precs[node.name]:
                    # Activation holder has the same precision as the input node. Remove it.
                    remove_list.append(node.name)
        else:
            # If all the input precisions are the same, the current node's precision
            # is that value; otherwise it is indeterminate
            input_precs = []
            for node_input in node.input:
                p = lookup_tensor_precision(model, precs, node_input)
                if not isinstance(p, IgnoreTensorPrecision):
                    input_precs.append(p)
            if len(input_precs) and input_precs.count(input_precs[0]) == len(input_precs):
                precs[node.name] = input_precs[0]
            else:
                # Use None to denote indeterminate precision
                precs[node.name] = None
    return remove_list
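
# Worked sketch (not from the original source; node names are hypothetical). For a chain
#     input -> Holder_A(16 bits) -> Conv -> Holder_B(16 bits) -> Relu -> Holder_C(8 bits)
# Holder_A sits on a model input and is removed, Holder_B carries the same precision as the
# value flowing into it and is removed, while Holder_C marks a 16 -> 8 bit switch and is
# kept, so get_redundant_holders would return ["Holder_A", "Holder_B"].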


def is_initializer(model: onnx.ModelProto, name: str) -> bool:
    """
    Check if a given string is an initializer.

    :param model: Onnx model with activation holders.
    :param name: Name of initializer to search.
    :return: True if initializer else False
    """
    for initializer in model.graph.initializer:
        if initializer.name == name:
            return True
    return False


def find_node_by_name(model: onnx.ModelProto, node_name: str) -> Optional[onnx.NodeProto]:
    """
    Find node given node name.

    :param model: Onnx model with activation holders.
    :param node_name: Name of node to search.
    :return: node. Node if it is found else None
    """
    for node in model.graph.node:
        if node.name == node_name:
            return node
    return None


def find_node_by_output(model: onnx.ModelProto, output_name: str) -> onnx.NodeProto:
    """
    Find node given node output name.

    :param model: Onnx model with activation holders.
    :param output_name: Name of node output to search.
    :return: Node whose output matches the output name.
        Raises an exception if a node with the specified output is not found.
    """
    for node in list(model.graph.node):
        if node.output[0] == output_name:
            return node
    raise ValueError(f"Cannot find node with output named {output_name}.")


def find_node_by_input(model: onnx.ModelProto, input_name: str) -> Tuple[onnx.NodeProto, int]:
    """
    Find an operator node whose input is the specified input_name.

    :param model: Loaded model in onnx.ModelProto representation.
    :param input_name: Name of the input.
    :return: Tuple of the node whose input matches input_name and the input index.
        Raises an exception if a node with the specified input is not found.
    """
    for node in list(model.graph.node):
        for i, node_input in enumerate(node.input):
            if node_input == input_name:
                return node, i
    raise ValueError(f"Cannot find node with input named {input_name}.")


def remove_node(model: onnx.ModelProto, node_name: str) -> None:
    """
    Remove node from model.

    :param model: Onnx model with activation holders.
    :param node_name: Name of the node to remove.
    :return: None.
    """
    node = find_node_by_name(model, node_name)
    assert node is not None and len(node.output) == 1
    is_last_node = any(node.output[0] == x.name for x in list(model.graph.output))
    # The node being removed must have exactly one non-initializer (true data) input
    true_input_to_removed_node = [name for name in node.input if not is_initializer(model, name)]
    assert len(true_input_to_removed_node) == 1
    if is_last_node:
        # The removed node produced a model output; let its producer drive that output instead
        connecting_node = find_node_by_output(model, true_input_to_removed_node[0])
        for i, model_output in enumerate(model.graph.output):
            if model_output.name == node.output[0]:
                connecting_node.output[0] = model.graph.output[i].name
    # Reconnect all consumers of the removed node's output to its data input
    following_nodes = []
    for node_check in list(model.graph.node):
        for i, node_input in enumerate(node_check.input):
            if node_input == node.output[0]:
                following_nodes.append((i, node_check))
    for input_idx, following_node in following_nodes:
        following_node.input[input_idx] = true_input_to_removed_node[0]
    model.graph.node.remove(node)


def replace_holders_with_annotations(model: onnx.ModelProto) -> None:
    """
    Replace MCT-based activation holders with AFE-specific annotations.

    :param model: Onnx model with activation holders.
    :return: None.
    """
    nodes = list(model.graph.node)
    for i, node in enumerate(nodes):
        if node.op_type == 'ActivationUniformQuantizer':
            annot = make_node(op_type="AnnotatePrecision",
                              inputs=node.input,
                              outputs=node.output,
                              domain="ai.sima",
                              precision=f"int{node.attribute[-1].i}")
            model.graph.node.remove(node)
            model.graph.node.insert(i, annot)


def set_sima_opset(model: onnx.ModelProto) -> None:
    """
    Include sima opset.

    :param model: Onnx model with activation holders.
    :return: None.
    """
    opset = make_opsetid("ai.sima", 1)
    model.opset_import.append(opset)


def annotate_model(promotion_list: List[str], annotated_onnx_filename: str) -> None:
    """
    Main function to annotate a model: update activation holder precisions,
    remove redundant holders so that only the ones indicating a precision change
    are kept, replace holders with AFE-specific annotations, and save the ONNX model.

    :param promotion_list: List of layer names to promote to 16 bit.
    :param annotated_onnx_filename: Onnx file path used to load the original model and save the new model.
    :return: None.
    """
    model = onnx.load(annotated_onnx_filename)
    model = remove_weight_quant(model)
    # Map layer names to the names of their activation holder nodes
    for i in range(len(promotion_list)):
        promotion_list[i] = '/' + promotion_list[i] + '/ActivationUniformQuantizer'
    update_precision(model, promotion_list)
    remove_list = get_redundant_holders(model)
    for node_name in remove_list:
        remove_node(model, node_name)
    replace_holders_with_annotations(model)
    set_sima_opset(model)
    onnx.checker.check_model(model)
    onnx.save(model, annotated_onnx_filename)
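
# Usage sketch (illustrative only, not part of this module; the file path and layer names
# below are hypothetical, and the real promotion list comes from the mixed-precision search):
#
#     annotate_model(promotion_list=["conv1", "features/conv2"],
#                    annotated_onnx_filename="model_mct_quantized.onnx")
#
# This loads the ONNX file, promotes the listed layers to 16 bit, keeps only the holders
# that mark a precision switch, rewrites them as ai.sima AnnotatePrecision nodes, and
# saves the annotated model back to the same path.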