#########################################################
# Copyright (C) 2024 SiMa Technologies, Inc.
#
# This material is SiMa proprietary and confidential.
#
# This material may not be copied or distributed without
# the express prior written permission of SiMa.
#
# All rights reserved.
#########################################################
# Code owner: Praneeth Medepalli
#########################################################
import numpy as np
from dataclasses import dataclass
from sima_utils.logging import sima_logger
from afe.ir.defines import NodeName
from afe.ir.net import AwesomeNet
from afe.ir.utils import is_depthwise_conv, is_group_conv
from afe.ir.node import AwesomeNode, node_is_awesomenet, node_is_sima_ir
from afe.ir.operations import ConvAddActivationOp, TupleConcatenateOp, MaxPool2DOp, AvgPool2DOp, AddActivationOp
from enum import Enum
# Small constant added to denominators to avoid division by zero errors.
EPSILON = 1e-30

# To limit excess scaling.
SMOOTHQ_MINVAL = 1e-2

# Lower clamp applied to equalization scales, to limit excess scaling.
EQUALIZATION_MINVAL = 1.0

# Upper clamp applied to equalization scales.
EQUALIZATION_MAXVAL = 20
# Enum class to represent the scaling status of a pairing set.
class ScaleStatus(Enum):
    """
    Scaling progress of a parent-children pairing set.

    NOTE(review): the member list is reconstructed from the names this module references
    (UNSCALED, TOSCALE, DONE). The numeric values are placeholders; only member identity
    is compared in this module. Confirm against the original definition.
    """
    # The pairing set has not been processed yet.
    UNSCALED = 0
    # The pairing set has been selected for scaling in the current concat pass.
    TOSCALE = 1
    # The pairing set has been scaled.
    DONE = 2
@dataclass
class PairingSet:
    """
    A parent-children pairing. It contains the identified parents, children, and
    a status on whether the set has been scaled.
    """
    # Nodes whose output channels get multiplied by the equalization scales.
    parents: list[AwesomeNode]
    # Nodes whose input channels get divided by the same scales.
    children: list[AwesomeNode]
    # Scaling progress of this set. Call sites in this module pass it as the third positional
    # argument; the default keeps two-argument construction working as well.
    status: ScaleStatus = ScaleStatus.UNSCALED
@dataclass
class ResidualAddFamily:
    """
    A collection of parents and children connected to an add-node. It contains the
    add-node, its parents and children, and a flag indicating if the node has been
    visited.
    """
    # The add-node this family is built around. It is the first positional argument at the
    # construction site and is used to look a family up by its add-node.
    add_node: AwesomeNode
    # Parents of the add-node. An entry may be None to mark a parent that is not eligible.
    parents: list[AwesomeNode]
    # Children of the add-node. An entry may be None to mark a child that is not eligible.
    children: list[AwesomeNode]
    # Set to True once the family has been folded into a residual grouping.
    visited: bool = False
def _get_child_nodes(net: AwesomeNet, curr_node: NodeName) -> tuple[list[AwesomeNode], int]:
    """
    Find all the ConvAddActivation node and AddActivation node successors to the specified node.
    Every add node found as a child increments a counter. If a Max or Average Pooling operator
    is detected, recurse down to analyze its children. (As the pooling nodes are linear and
    separable across channels, their children can be treated as curr_node's children.)

    :param net: Net to analyze.
    :param curr_node: Name of the node whose children are analyzed.
    :return: List of children nodes of the current node, and the number of add nodes found among
        them (including those found through pooling nodes). A list entry is None for every child
        that is neither a conv (single-input child) nor an add node (multi-input child).
    """
    children_nodes = []
    add_node_count = 0
    for node in net.nodes.values():
        if curr_node not in node.input_node_names:
            continue
        operation = node.ir.operation
        if len(node.input_node_names) == 1:
            # Single-input child: keep convs, look through pooling, reject anything else.
            if isinstance(operation, (MaxPool2DOp, AvgPool2DOp)):
                nested_children, nested_add_count = _get_child_nodes(net, node.name)
                children_nodes.extend(nested_children)
                add_node_count += nested_add_count
            elif isinstance(operation, ConvAddActivationOp):
                children_nodes.append(node)
            else:
                children_nodes.append(None)  # If child is anything but conv, store none
        elif isinstance(operation, AddActivationOp):
            # Multi-input child: only an add-node is valid.
            add_node_count += 1
            children_nodes.append(node)
        else:
            children_nodes.append(None)  # If child is anything but an add-node, store none
    return children_nodes, add_node_count
def _get_parent_nodes(net: AwesomeNet, curr_node: AwesomeNode) -> tuple[list[AwesomeNode], int]:
    """
    Find all the ConvAddActivation node and AddActivation node predecessors to the specified node.
    Every add node found as a parent increments a counter. If a Max or Average Pooling operator
    is detected, recurse up to analyze its parents. (As the pooling operators are linear and
    separable across channels, their parents can be treated as curr_node's parents.)

    :param net: Net to analyze.
    :param curr_node: The current node whose parents are analyzed.
    :return: List of parent nodes of the current node, and the number of add nodes found among
        them (including those found through pooling nodes). A list entry is None for every parent
        that is not a conv, an add node, or a pooling node. Because pooling nodes are looked
        through, the same node can appear more than once in the list.
    """
    parent_nodes = []
    add_node_count = 0
    for input_node_name in curr_node.input_node_names:
        node = net.nodes[input_node_name]
        operation = node.ir.operation
        if isinstance(operation, (MaxPool2DOp, AvgPool2DOp)):
            # Pooling is transparent for pairing purposes: use its parents instead.
            nested_parents, nested_add_count = _get_parent_nodes(net, node)
            parent_nodes.extend(nested_parents)
            add_node_count += nested_add_count
        elif isinstance(operation, ConvAddActivationOp):
            parent_nodes.append(node)
        elif isinstance(operation, AddActivationOp):
            add_node_count += 1
            parent_nodes.append(node)
        else:
            parent_nodes.append(None)  # Invalid parent, marks the set as not eligible
    return parent_nodes, add_node_count
[docs]
def get_pairing_lists(net: AwesomeNet) -> tuple[list[PairingSet], list[PairingSet], list[ResidualAddFamily]]:
    """
    Walk an AwesomeNet in execution order and collect candidate parent-children groupings
    for three patterns:
    1. A conv node whose children are exclusively conv nodes.
    2. A concat node (along axis 3) whose parents and children are exclusively conv nodes.
    3. An add node together with its parents and children (validated later by the residual pass).
    Nested AwesomeNets are processed recursively and their results are appended.

    :param net: AwesomeNet to iterate through.
    :return: Tuple of (conv pairing sets, concat pairing sets, add-node families). Pairing sets
        for patterns 1 and 2 are created with status UNSCALED.
    """
    conv_pairings: list[PairingSet] = []
    concat_pairings: list[PairingSet] = []
    add_families: list[ResidualAddFamily] = []

    # Topological sort to ensure the pairing pass follows operation execution order
    net.topological_sort()
    for node_name in net.execution_order:
        node = net.nodes[node_name]

        if node_is_awesomenet(node):
            nested_result = get_pairing_lists(node.ir)
            conv_pairings.extend(nested_result[0])
            concat_pairings.extend(nested_result[1])
            add_families.extend(nested_result[2])
            continue

        assert node_is_sima_ir(node)

        if isinstance(node.ir.operation, AddActivationOp):
            # Collect the add-node's immediate (non-pooling) parents and children. Invalid
            # neighbours are recorded as None, so that the residual pass can discard the set.
            parent_nodes, _ = _get_parent_nodes(net, node)
            children_nodes, _ = _get_child_nodes(net, node.name)
            # Conv parents may feed other branches; those consumers belong to the family too.
            for parent in parent_nodes:
                if parent is None or not isinstance(parent.ir.operation, ConvAddActivationOp):
                    continue
                # May contain Nones if the parent has invalid children
                branch_children, _ = _get_child_nodes(net, parent.name)
                # Add-nodes are skipped here to avoid infinite recursion later on.
                children_nodes.extend(
                    sibling for sibling in branch_children
                    if sibling is None or not isinstance(sibling.ir.operation, AddActivationOp)
                )
            add_families.append(ResidualAddFamily(node, parent_nodes, children_nodes))

        elif isinstance(node.ir.operation, TupleConcatenateOp):
            # Only concatenation along axis 3 forms a pairing set
            if node.ir.attrs.concat_attrs.axis != 3:
                continue
            child_convs, num_add_children = _get_child_nodes(net, node_name)
            parent_convs, num_add_parents = _get_parent_nodes(net, node)
            has_add_node = (num_add_children > 0) or (num_add_parents > 0)
            # Any non-conv neighbour, any add-node, or an empty side discards the whole set.
            rejected = (
                None in parent_convs
                or None in child_convs
                or has_add_node
                or len(child_convs) == 0
                or len(parent_convs) == 0
            )
            if not rejected:
                concat_pairings.append(PairingSet(parent_convs, child_convs, ScaleStatus.UNSCALED))

        elif isinstance(node.ir.operation, ConvAddActivationOp):
            child_convs, num_add_children = _get_child_nodes(net, node_name)
            # Any non-conv child, any add-node child, or no children discards the set.
            rejected = None in child_convs or num_add_children > 0 or len(child_convs) == 0
            if not rejected:
                conv_pairings.append(PairingSet([node], child_convs, ScaleStatus.UNSCALED))

    return conv_pairings, concat_pairings, add_families
[docs]
def get_pairings(net: AwesomeNet) -> tuple[list[PairingSet], list[PairingSet], list[PairingSet]]:
    """
    Entry point of the pairing framework: collect the raw pairings from the network and
    post-process the add-node families into residual pairing sets.

    The goal is to assemble self-contained parent-children 'pairing sets' such that each set can
    be scaled without affecting the input activations to the set or the output activations from
    the set branching to other nodes in the network. Note that a node can appear in multiple
    pairing sets, but is always scaled at most once as a parent and at most once as a child. Thus,
    this framework only guarantees that activations going in and coming out of a set are invariant
    to scaling within a set, but does not guarantee that weight changes are self-contained to a
    single set.

    :param net: Network that is traversed to find node-pairs.
    :return: Tuple of (conv pairing sets, concat pairing sets, residual pairing sets).
    """
    conv_pairings, concat_pairings, add_families = get_pairing_lists(net)
    # Merge add-nodes that are connected to each other into groupings, pooling their parents
    # and children, and keep only the groupings that are valid to scale.
    return conv_pairings, concat_pairings, residual_pair_pass(add_families)
[docs]
def residual_pair_pass(add_node_family_list: list[ResidualAddFamily]) -> list[PairingSet]:
    """
    Merge add-node families that are connected through add-nodes into residual groupings and
    keep the groupings that are eligible for scaling.

    Every family that has not been visited yet is expanded via get_add_connections, which pools
    the parents and children of all add-nodes reachable from it. A grouping is kept only if it
    contains no None entry (None marks a neighbour that is not a conv) and has at least one
    parent and one child. All members of a grouping must be scaled with the same scale.

    :param add_node_family_list: List of add-node families (add-node, its parents and children,
        and a visited flag that is updated by this pass).
    :return: List of pairing sets, one per valid residual grouping, with status UNSCALED.
    """
    residual_pairings: list[PairingSet] = []
    for family_idx, family in enumerate(add_node_family_list):
        if family.visited:
            continue
        grouped_parents, grouped_children = get_add_connections(add_node_family_list, family_idx)

        # A single invalid neighbour disqualifies the whole grouping
        if None in grouped_parents or None in grouped_children:
            continue
        if len(grouped_parents) == 0 or len(grouped_children) == 0:
            continue

        # Sanity check: everything that is left must be a conv
        conv_parents = [p for p in grouped_parents if isinstance(p.ir.operation, ConvAddActivationOp)]
        conv_children = [c for c in grouped_children if isinstance(c.ir.operation, ConvAddActivationOp)]
        assert len(conv_parents) == len(grouped_parents), 'Non-conv nodes detected in final parents list'
        assert len(conv_children) == len(grouped_children), 'Non-conv nodes detected in final children list'

        residual_pairings.append(PairingSet(grouped_parents, grouped_children, ScaleStatus.UNSCALED))
    return residual_pairings
[docs]
def get_new_add_index(add_node_family_list: list[ResidualAddFamily], add_node_to_find: AwesomeNode) -> int:
    """
    Return the position of the family that is built around the given add-node.

    :param add_node_family_list: List of add-node families, each containing a distinct add-node
        and its immediate parents and children.
    :param add_node_to_find: Desired add-node.
    :return: Index of the family whose add-node equals the desired add-node. list.index raises
        ValueError if no family matches.
    """
    return [family.add_node for family in add_node_family_list].index(add_node_to_find)
[docs]
def get_add_connections(add_node_family_list: list[ResidualAddFamily], add_node_idx: int) \
        -> tuple[list[AwesomeNode], list[AwesomeNode]]:
    """
    Collect the parents and children of the residual grouping that contains the given add-node.

    The family at add_node_idx is marked as visited. Its parents are scanned first, then its
    children. A conv neighbour is stored directly, a None neighbour is stored as None (it marks
    the grouping as invalid), and an add-node neighbour that has not been visited yet is expanded
    recursively, with its parents and children pooled into the result.

    :param add_node_family_list: List of add-node families, each containing a distinct add-node
        and its immediate parents and children.
    :param add_node_idx: Index of the add-node family that is currently scanned.
    :return: List of parents and list of children of the residual grouping.
    :raises sima_logger.UserFacingException: If a neighbour is neither None, a conv, nor an add-node.
    """
    family = add_node_family_list[add_node_idx]
    family.visited = True
    final_parents = []
    final_children = []

    def _absorb(neighbours, direct_sink):
        # Store direct neighbours in direct_sink; results of nested add-nodes always go to
        # both result lists.
        for neighbour in neighbours:
            if neighbour is None:
                direct_sink.append(None)
                continue
            if isinstance(neighbour.ir.operation, ConvAddActivationOp):
                direct_sink.append(neighbour)
            elif isinstance(neighbour.ir.operation, AddActivationOp):
                linked_idx = get_new_add_index(add_node_family_list, neighbour)
                if add_node_family_list[linked_idx].visited:
                    continue
                nested_parents, nested_children = get_add_connections(add_node_family_list, linked_idx)
                final_parents.extend(nested_parents)
                final_children.extend(nested_children)
            else:
                # Adjacent literals instead of a backslash continuation, so that source
                # indentation can never leak into the message.
                raise sima_logger.UserFacingException(
                    f"Unexpected node {neighbour.name} "
                    "encountered in residual pairing search."
                )

    _absorb(family.parents, final_parents)
    _absorb(family.children, final_children)
    return final_parents, final_children
[docs]
def get_equalization_scale(parent_node: AwesomeNode) -> tuple[np.ndarray, float, float]:
    """
    Compute channel-wise equalization scales from the output activation extrema recorded by
    the parent node's calibration observer. Based on the Same, Same, but Different paper:
    https://proceedings.mlr.press/v97/meller19a/meller19a.pdf.

    :param parent_node: The first node in the node-pair.
    :return: Channel-wise scales (clamped to [EQUALIZATION_MINVAL, EQUALIZATION_MAXVAL]), the
        largest absolute per-channel maximum before scaling, and the largest per-channel maximum
        after scaling.
    """
    channel_mins, channel_maxes = parent_node.ir.calib_attrs.observer.min_max()
    channel_magnitudes = np.maximum(np.abs(channel_mins), np.abs(channel_maxes))
    # NOTE(review): the reference value is taken from the channel maxima only, while the
    # per-channel magnitudes also consider the minima -- confirm this asymmetry is intended.
    max_val = np.max(np.abs(channel_maxes))

    # Ratio of the reference value to each channel's magnitude; never let it fall below the
    # lower bound.
    lower_bounded = np.maximum(max_val / (channel_magnitudes + EPSILON), EQUALIZATION_MINVAL)
    # The original paper picks the minimum of weight scale and activation scale.
    # As we have per-channel quantization for weight, we don't need to consider weight scales.
    scales = np.minimum(lower_bounded, EQUALIZATION_MAXVAL)

    # Largest channel maximum once the scales are applied
    return scales, max_val, np.max(channel_maxes * scales)
[docs]
def parent_node_update(parent_node: AwesomeNode, scales: np.ndarray):
    """
    Multiply a parent node's weights, and its bias if present, by the given scales in place.

    :param parent_node: AwesomeNode to update
    :param scales: Numpy array of scales. It is reshaped to the last two dimensions of the weight
        tensor (the (G, O) axes according to the original implementation) and flattened for
        the bias.
    :return: None.
    """
    node_attrs = parent_node.ir.attrs
    # Shape of the trailing two weight axes; broadcasting covers the leading dimensions.
    trailing_shape = node_attrs.weights_attrs.data.shape[-2:]
    node_attrs.weights_attrs.data *= scales.reshape(trailing_shape)
    if node_attrs.bias_attrs:
        node_attrs.bias_attrs.data *= scales.flatten()
[docs]
def child_node_update(child_node: AwesomeNode, scales: np.ndarray):
    """
    Divide a child node's weights by the given scales in place.

    :param child_node: AwesomeNode to update
    :param scales: Numpy array of scales whose size must equal groups * weight_shape[-3] of the
        child's convolution, so that it can be reshaped to (groups, weight_shape[-3]).
    :return: None.
    """
    conv_attrs = child_node.ir.attrs.conv_attrs
    channels_per_group = conv_attrs.weight_shape[-3]
    # Lay the scales out as (I, G) to match the corresponding weight axes ...
    per_input_scales = scales.reshape((conv_attrs.groups, channels_per_group)).T
    # ... and append a unit axis for O so that the division broadcasts over output channels.
    divisors = per_input_scales[:, :, np.newaxis]
    child_node.ir.attrs.weights_attrs.data /= divisors
[docs]
def scale_concat_children(concat_node_pairing_list: list[PairingSet],
                          parent_scale_dict: dict[NodeName, np.ndarray]) -> None:
    """
    Rescale the children of every concat pairing set that is marked as TOSCALE.

    For each such set, the scales of its parents are concatenated in parent order, every child's
    weights are divided by the concatenated scales, and the set is marked as DONE.

    :param concat_node_pairing_list: List of pairing sets of AwesomeNode parents and children that
        are connected by a concat node.
    :param parent_scale_dict: Dictionary mapping parent node names to their scales. It must contain
        an entry for every parent of every set that is marked as TOSCALE.
    :return: None
    """
    for pairing in concat_node_pairing_list:
        if pairing.status != ScaleStatus.TOSCALE:
            continue
        # The scales are keyed by node name, in the order in which the parents feed the concat.
        scales = np.concatenate([parent_scale_dict[parent_node.name] for parent_node in pairing.parents])
        for child_node in pairing.children:
            child_node_update(child_node, scales)
        pairing.status = ScaleStatus.DONE  # Mark as done
[docs]
def scale_concat_parents(concat_node_pairing: PairingSet) -> dict[NodeName, np.ndarray]:
    """
    Compute the scales for the parents of a concat pairing set and apply them to the parents.

    First the equalization scales of every parent are computed. They are then multiplied by a
    per-parent factor so that the maximum activation is the same across all concatenated channels,
    and the parents' weights (and biases) are rescaled accordingly.

    A parent can be listed more than once, e.g. when it reaches the concat both directly and
    through a pooling node (_get_parent_nodes looks through pooling nodes). Each distinct parent
    is rescaled exactly once; rescaling it once per occurrence would compound the scales on its
    weights while the children are compensated by a single scale only.

    :param concat_node_pairing: Pairing set for a given concat node, containing a list of parents,
        children, and status. Only the parents list is used here.
    :return: Dictionary of parent node names and corresponding scales as numpy arrays.
    """
    # Distinct parents keyed by name, in first-seen order.
    unique_parents = {}
    for parent_node in concat_node_pairing.parents:
        unique_parents.setdefault(parent_node.name, parent_node)

    parent_scales = {}        # Dict of parent node name: scales
    parent_scaled_maxes = {}  # Dict of parent node name: max activation after scaling
    largest_max_val = 0
    # Get scale and max values
    for parent_name, parent_node in unique_parents.items():
        parent_scales[parent_name], max_val, scaled_max = get_equalization_scale(parent_node)
        parent_scaled_maxes[parent_name] = scaled_max
        if max_val > largest_max_val:
            largest_max_val = max_val

    # Modify scale by a multiplicative factor such that the maximum activation is the same
    # across all concatenated channels
    for parent_name, parent_node in unique_parents.items():
        parent_scales[parent_name] = parent_scales[parent_name] * largest_max_val \
            / (parent_scaled_maxes[parent_name] + EPSILON)
        parent_node_update(parent_node, parent_scales[parent_name])
    return parent_scales
def _is_subset(sub_list: list[AwesomeNode], super_list: list[AwesomeNode]) -> bool:
    """
    Check whether every element of one list is also contained in another list.

    :param sub_list: initial set of elements
    :param super_list: potential superset
    :return: True if super_list contains every element of sub_list (always True for an empty
        sub_list)
    """
    for node in sub_list:
        if node not in super_list:
            return False
    return True
[docs]
def find_largest_parent_set(concat_node_pairing_list: list[PairingSet], parents_list: list[AwesomeNode],
                            idx_parents_list: int) -> int:
    """
    Find the concat pairing set with the largest parent list that contains all the given parents.
    Every pairing set whose parents contain all the given parents is marked as TOSCALE, regardless
    of its previous status. This function operates on the key assumption that all parent sets are
    subsets of larger parent sets.

    :param concat_node_pairing_list: List of pairing sets of AwesomeNode parents and children
        that are connected by a concat node.
    :param parents_list: List of the initial set of parents for which we are trying to find supersets.
    :param idx_parents_list: Index of parents_list in the full concat_node_pairing_list. It is
        returned if no strictly larger parent set is found.
    :return: Index of largest set of parents
    """
    best_idx = idx_parents_list
    best_size = len(parents_list)
    for candidate_idx, candidate in enumerate(concat_node_pairing_list):
        if not _is_subset(parents_list, candidate.parents):
            continue
        candidate.status = ScaleStatus.TOSCALE
        candidate_size = len(candidate.parents)
        # Strict comparison: on ties the earliest candidate is kept
        if candidate_size > best_size:
            best_size, best_idx = candidate_size, candidate_idx
    return best_idx
[docs]
def find_residual_scale(parent_nodes: list[AwesomeNode]) -> np.ndarray:
    """
    Combine the equalization scales of all parents of a residual grouping into one set of scales.

    The first parent's scales are taken as they are. For every further parent, both the running
    scales and the parent's scales are normalized by their own minimum, the element-wise minimum
    of the two is taken, and the result is clamped from below to EQUALIZATION_MINVAL.

    :param parent_nodes: List of AwesomeNode parents.
    :return: Residual scales. None is returned if parent_nodes is empty.
    """
    residual_scales = None
    for parent_node in parent_nodes:
        node_scales, _, _ = get_equalization_scale(parent_node)
        if residual_scales is None:
            residual_scales = node_scales.copy()
            continue
        # Take the smaller normalized scale so none of the activation maxes are changed
        normalized_running = residual_scales / np.min(residual_scales)
        normalized_node = node_scales / np.min(node_scales)
        elementwise_min = np.minimum(normalized_running, normalized_node)
        residual_scales = np.maximum(elementwise_min, EQUALIZATION_MINVAL)
    return residual_scales
[docs]
def pairings_update_pass(pairings_lists: tuple[list[PairingSet], list[PairingSet], list[PairingSet]]):
    """
    Scale all identified pairing sets channel-wise, based on channel equalization. The model
    weights are updated in place.

    Single-conv-parent pairings: there is exactly one parent, which may have several children.
    The scale is found from the parent node, applied to the parent, and used to inversely scale
    the children nodes.

    Concat pairings: for a given concat, the scales of all its parents are computed and then
    adjusted to yield the same max value. The parents' weights are scaled, the scales are
    concatenated, and the children's weights are inversely scaled. The algorithm makes the key
    assumption that for a parent set of a concat node, any other concat node that shares one of the
    parents will share all the parents (i.e., concats are structured in hierarchical manner, where
    parent sets are subsets of larger parent sets). This assumption is also made by the Same, Same,
    but Different paper's channel equalization implementation:
    https://github.com/icml2019/equalization.

    Residual pairings: the minimum of normalized scales across all parents is computed first, then
    all parents and children are scaled by that same scale.

    :param pairings_lists: A tuple containing three lists of node pairings to scale:
        conv_node_pairing_list, concat_node_pairing_list, residual_node_pairing_list
    :return: Updates model weights in-place, returns nothing.
    """
    conv_pairings, concat_pairings, residual_pairings = pairings_lists

    # Single conv parent pairings
    for conv_pairing in conv_pairings:
        parent_nodes = conv_pairing.parents
        assert len(parent_nodes) == 1, f"Expect only one parent node but got {len(parent_nodes)}"
        sole_parent = parent_nodes[0]
        scales, _, _ = get_equalization_scale(sole_parent)
        parent_node_update(sole_parent, scales)
        for child_node in conv_pairing.children:
            child_node_update(child_node, scales)
        conv_pairing.status = ScaleStatus.DONE

    # Concat pairings. All concats that share a given parent set are scaled at once.
    for pairing_idx, concat_pairing in enumerate(concat_pairings):
        # TOSCALE is a transient state that must never be visible between iterations;
        # the status is either UNSCALED or DONE here.
        assert concat_pairing.status != ScaleStatus.TOSCALE
        if concat_pairing.status != ScaleStatus.UNSCALED:
            continue
        # Mark every related set as TOSCALE and locate the largest parent set, so that the
        # scales are computed once.
        largest_idx = find_largest_parent_set(concat_pairings, concat_pairing.parents, pairing_idx)
        # Compute the scales and update the parents' weights. Yields node name -> scales.
        parent_scale_dict = scale_concat_parents(concat_pairings[largest_idx])
        # Inversely scale the children of all marked sets and flag those sets as DONE
        scale_concat_children(concat_pairings, parent_scale_dict)

    # Residual add pairings: one shared scale for every parent and child of a grouping
    for residual_pairing in residual_pairings:
        residual_scales = find_residual_scale(residual_pairing.parents)
        for parent_node in residual_pairing.parents:
            parent_node_update(parent_node, residual_scales)
        for child_node in residual_pairing.children:
            child_node_update(child_node, residual_scales)
        residual_pairing.status = ScaleStatus.DONE