#########################################################
# Copyright (C) 2024 SiMa Technologies, Inc.
#
# This material is SiMa proprietary and confidential.
#
# This material may not be copied or distributed without
# the express prior written permission of SiMa.
#
# All rights reserved.
#########################################################
# Code owner: Shreyas Kera
#########################################################
from typing import List, Tuple, Optional, Dict, Union
import onnx
from onnx.helper import make_node, make_opsetid
def remove_weight_quant(model: onnx.ModelProto) -> onnx.ModelProto:
    """
    Removes weight quantizers in an MCT-quantized model.
    :param model: Onnx model with weight holders.
    :return: model. Onnx model without weight holders.
    """
    # Iterate over a snapshot of the node list, since nodes are removed from the graph inside the loop
    for node in list(model.graph.node):
        if node.op_type == 'WeightsSymmetricQuantizer':
            next_node, index = find_node_by_input(model, node.output[0])
            # Connect next node's input to weight holder's weight
            next_node.input[index] = node.input[0]
            constant_name = node.input[1]
            # Remove weight holder
            model.graph.node.remove(node)
            constant_node = find_node_by_output(model, constant_name)
            # Remove constant associated with the weight holder
            model.graph.node.remove(constant_node)
    return model
def update_precision(model: onnx.ModelProto, promotion_list: List[str]) -> onnx.ModelProto:
    """
    Updates the precision of layers to promote in the annotated onnx model.
    :param model: Onnx model with activation holders.
    :param promotion_list: List of layer names to promote to 16 bit.
    :return: model. Onnx model with activation holders with appropriate bit settings.
    """
    for node in model.graph.node:
        if node.name in promotion_list:
            # The num_bits attribute is taken to be the last attribute of the holder node
            attr = onnx.helper.make_attribute("num_bits", 16)
            node.attribute[-1].CopyFrom(attr)
    return model
# Returned to indicate that a tensor's precision should be ignored
class IgnoreTensorPrecision:
    pass
def lookup_tensor_precision(model: onnx.ModelProto, precs: Dict[str, Optional[int]], tensor: str) \
        -> Union[None, IgnoreTensorPrecision, int]:
    """
    Find the precision that mixed precision search has chosen for a tensor according to
    attributes of Activation Holders.
    :param model: Onnx model with activation holders.
    :param precs: Precisions associated with nodes calculated so far. It holds at least the nodes that
        topologically precede the tensor.
    :param tensor: Tensor to look up.
    :return: Precision of the tensor.
    """
    if tensor == "":
        return IgnoreTensorPrecision()
    if any(model_input.name == tensor for model_input in model.graph.input):
        # Inputs have int16 precision
        return 16
    elif is_constant_tensor(model, tensor):
        # Constants are ignored
        return IgnoreTensorPrecision()
    else:
        # All other tensors must be an output of a node
        node_check = find_node_by_output(model, tensor)
        return precs[node_check.name]
def is_constant_tensor(model: onnx.ModelProto, tensor: str) -> bool:
    """
    Return true if the tensor is an initializer or the output of a constant operator.
    """
    if any(initializer.name == tensor for initializer in model.graph.initializer):
        return True
    for node in model.graph.node:
        if node.output[0] == tensor:
            # The tensor is the output of this node. Decide whether it is a constant operator.
            return node.op_type == 'Constant'
    return False
def get_redundant_holders(model: onnx.ModelProto) -> List[str]:
    """
    Find redundant Activation Holders from topologically sorted onnx nodes; only those Activation Holders
    that indicate a precision switch should be kept.
    :param model: Onnx model with activation holders.
    :return: remove_list. List of redundant activation holders to remove.
    """
    # Maintain a dictionary of nodes and corresponding precision
    precs = {}
    remove_list = []
    for node in model.graph.node:
        if node.op_type == 'ActivationUniformQuantizer':
            node_input = node.input[0]
            precs[node.name] = int(node.attribute[-1].i)
            if any(model_input.name == node_input for model_input in model.graph.input) \
                    or is_constant_tensor(model, node_input):
                # Activation holder is on a constant or an input of the model. Remove it.
                remove_list.append(node.name)
            else:
                input_prec = lookup_tensor_precision(model, precs, node_input)
                if isinstance(input_prec, int) and input_prec == precs[node.name]:
                    # Activation holder has the same precision as the input node. Remove it.
                    remove_list.append(node.name)
        else:
            # If all the input precisions are the same, the current node's precision is that value;
            # otherwise it is indeterminate.
            input_precs = []
            for node_input in node.input:
                p = lookup_tensor_precision(model, precs, node_input)
                if not isinstance(p, IgnoreTensorPrecision):
                    input_precs.append(p)
            if len(input_precs) and input_precs.count(input_precs[0]) == len(input_precs):
                precs[node.name] = input_precs[0]
            else:
                # Use None to denote indeterminate precision
                precs[node.name] = None
    return remove_list
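# A small worked example of the rule above, using a hypothetical graph
#   input (int16) -> holder1 (num_bits=16) -> Conv -> holder2 (num_bits=8) -> Relu -> holder3 (num_bits=8)
# holder1 is marked for removal because it sits on a model input, and holder3 is marked for removal
# because its precision matches the precision inferred for Relu's input; only holder2, which marks
# the int16 -> int8 switch after Conv, is kept.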
def is_initializer(model: onnx.ModelProto, name: str) -> bool:
    """
    Check if a given string is an initializer.
    :param model: Onnx model with activation holders.
    :param name: Name of initializer to search.
    :return: True if initializer else False
    """
    for initializer in model.graph.initializer:
        if initializer.name == name:
            return True
    return False
def find_node_by_name(model: onnx.ModelProto, node_name: str) -> Optional[onnx.NodeProto]:
    """
    Find node given node name.
    :param model: Onnx model with activation holders.
    :param node_name: Name of node to search.
    :return: node. Node if it is found else None
    """
    for node in model.graph.node:
        if node.name == node_name:
            return node
    return None
def find_node_by_output(model: onnx.ModelProto, output_name: str) -> onnx.NodeProto:
    """
    Find node given node output name.
    :param model: Onnx model with activation holders.
    :param output_name: Name of node output to search.
    :return: Node whose output matches the output name.
        Raises an exception if a node with the specified output is not found.
    """
    for node in list(model.graph.node):
        if node.output[0] == output_name:
            return node
    raise ValueError(f"Cannot find node with output named {output_name}.")
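# NOTE: remove_weight_quant() depends on a find_node_by_input() helper. The sketch below is a
# minimal assumed implementation, returning the first consuming node together with the index of
# the matching input; the actual implementation may differ.
def find_node_by_input(model: onnx.ModelProto, input_name: str) -> Tuple[onnx.NodeProto, int]:
    """
    Find a node given the name of one of its inputs.
    :param model: Onnx model with activation holders.
    :param input_name: Name of the node input to search.
    :return: Tuple of the consuming node and the index of the matching input.
        Raises an exception if no node consumes the named tensor.
    """
    for node in model.graph.node:
        for index, node_input in enumerate(node.input):
            if node_input == input_name:
                return node, index
    raise ValueError(f"Cannot find node with input named {input_name}.")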
def remove_node(model: onnx.ModelProto, node_name: str) -> None:
    """
    Remove node from model.
    :param model: Onnx model with activation holders.
    :param node_name: Name of the node to remove.
    :return: None.
    """
    node = find_node_by_name(model, node_name)
    assert node is not None and len(node.output) == 1
    is_last_node = any(node.output[0] == x.name for x in list(model.graph.output))
    true_input_to_removed_node = [name for name in node.input if not is_initializer(model, name)]
    assert len(true_input_to_removed_node) == 1
    if is_last_node:
        # The removed node produces a model output; let its producer drive that output instead
        connecting_node = find_node_by_output(model, true_input_to_removed_node[0])
        for i, model_output in enumerate(model.graph.output):
            if model_output.name == node.output[0]:
                connecting_node.output[0] = model.graph.output[i].name
    # Reconnect every consumer of the removed node's output to the removed node's true input
    following_nodes = []
    for node_check in list(model.graph.node):
        for i, node_input in enumerate(node_check.input):
            if node_input == node.output[0]:
                following_nodes.append((i, node_check))
    for input_idx, following_node in following_nodes:
        following_node.input[input_idx] = true_input_to_removed_node[0]
    model.graph.node.remove(node)
def replace_holders_with_annotations(model: onnx.ModelProto) -> None:
    """
    Replace MCT-based activation holders with AFE-specified annotations.
    :param model: Onnx model with activation holders.
    :return: None.
    """
    nodes = list(model.graph.node)
    for i, node in enumerate(nodes):
        if node.op_type == 'ActivationUniformQuantizer':
            annot = make_node(op_type="AnnotatePrecision",
                              inputs=node.input,
                              outputs=node.output,
                              domain="ai.sima",
                              precision=f"int{node.attribute[-1].i}")
            model.graph.node.remove(node)
            model.graph.node.insert(i, annot)
def set_sima_opset(model: onnx.ModelProto) -> None:
    """
    Include sima opset.
    :param model: Onnx model with activation holders.
    :return: None.
    """
    opset = make_opsetid("ai.sima", 1)
    model.opset_import.append(opset)
def annotate_model(promotion_list: List[str], annotated_onnx_filename: str) -> None:
    """
    Main function to annotate the model: update activation holder precisions, remove redundant holders
    so that only the ones indicating a precision change are kept, replace holders with AFE-specific
    annotations, and save the onnx model.
    :param promotion_list: List of layer names to promote to 16 bit.
    :param annotated_onnx_filename: Onnx file path used to load the original model and save the new model.
    :return: None.
    """
    model = onnx.load(annotated_onnx_filename)
    model = remove_weight_quant(model)
    for i in range(len(promotion_list)):
        # Map each layer name to the name of its MCT activation quantizer node
        promotion_list[i] = '/' + promotion_list[i] + '/ActivationUniformQuantizer'
    update_precision(model, promotion_list)
    remove_list = get_redundant_holders(model)
    for node_name in remove_list:
        remove_node(model, node_name)
    replace_holders_with_annotations(model)
    set_sima_opset(model)
    onnx.checker.check_model(model)
    onnx.save(model, annotated_onnx_filename)
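# Example usage (hypothetical layer names and file path; annotate_model rewrites the file in place):
#
#     annotate_model(promotion_list=["conv1", "block1/conv2"],
#                    annotated_onnx_filename="model_mct_quantized.onnx")
#
# Each layer name is expanded to "/<layer>/ActivationUniformQuantizer" before the matching
# activation holder's num_bits attribute is set to 16.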