Source code for afe.ir.transform.channel_scaling

#########################################################
# Copyright (C) 2024 SiMa Technologies, Inc.
#
# This material is SiMa proprietary and confidential.
#
# This material may not be copied or distributed without
# the express prior written permission of SiMa.
#
# All rights reserved.
#########################################################
# Code owner: Praneeth Medepalli
#########################################################
import numpy as np
from dataclasses import dataclass
from sima_utils.logging import sima_logger
from afe.ir.defines import NodeName
from afe.ir.net import AwesomeNet
from afe.ir.utils import is_depthwise_conv, is_group_conv
from afe.ir.node import AwesomeNode, node_is_awesomenet, node_is_sima_ir
from afe.ir.operations import ConvAddActivationOp, TupleConcatenateOp, MaxPool2DOp, AvgPool2DOp, AddActivationOp
from enum import Enum

EPSILON = 1e-30  # To avoid division by zero errors
SMOOTHQ_MINVAL = 1e-2  # To limit excess scaling
EQUALIZATION_MINVAL = 1.0  # To limit excess scaling
EQUALIZATION_MAXVAL = 20

# Enum representing the scaling status of a pairing set
class ScaleStatus(Enum):
    UNSCALED = 0
    TOSCALE = 1
    DONE = 2

@dataclass
class PairingSet:
    """
    A parent-children pairing. It contains the identified parents, children, and a status
    on whether the set has been scaled.
    """
    parents: list[AwesomeNode]
    children: list[AwesomeNode]
    status: ScaleStatus

@dataclass
class ResidualAddFamily:
    """
    A collection of parents and children connected to an add-node. It contains the add-node,
    its parents and children, and a flag indicating if the node has been visited.
    """
    add_node: AwesomeNode
    parents: list[AwesomeNode]
    children: list[AwesomeNode]
    visited: bool = False

def _get_child_nodes(net: AwesomeNet, curr_node: NodeName) -> tuple[list[AwesomeNode], int]:
    """
    Find all the ConvAddActivation node and AddActivation node successors to the specified node.
    If an add node is detected as a child, update a counter indicating an add node is present.
    If a Max or Average Pooling operator is detected, recurse down to analyze its children.
    (As the pooling nodes are linear and separable across channels, their children can be
    treated as curr_node's children.)

    :param net: Net to analyze.
    :param curr_node: The current node whose children are analyzed.
    :return: List of children nodes to the current node, and a count of add nodes detected
        among the children.
    """
    children_nodes = []
    add_node_present = 0
    for node in net.nodes.values():
        if curr_node in node.input_node_names:
            # If the child node has multiple parents, store the node only if it's valid
            if len(node.input_node_names) == 1:
                # If the child is of a valid operation type, store it as a child node;
                # otherwise recurse further down to find a conv child node
                if isinstance(node.ir.operation, (MaxPool2DOp, AvgPool2DOp)):
                    updated_children_nodes, updated_add_nodes = _get_child_nodes(net, node.name)
                    children_nodes.extend(updated_children_nodes)
                    add_node_present += updated_add_nodes
                elif isinstance(node.ir.operation, ConvAddActivationOp):
                    children_nodes.append(node)
                else:
                    children_nodes.append(None)  # If the child is anything but a conv, store None
            else:
                if isinstance(node.ir.operation, AddActivationOp):
                    add_node_present += 1
                    children_nodes.append(node)
                else:
                    children_nodes.append(None)  # If the child is anything but an add-node, store None
    return children_nodes, add_node_present


def _get_parent_nodes(net: AwesomeNet, curr_node: AwesomeNode) -> tuple[list[AwesomeNode], int]:
    """
    Find all the ConvAddActivation node and AddActivation node predecessors to the specified node.
    If an add node is detected as a parent, update a counter indicating an add node is present.
    If a Max or Average Pooling operator is detected, recurse up to analyze its parents.
    (As the pooling operators are linear and separable across channels, their parents can be
    treated as curr_node's parents.)

    :param net: Net to analyze.
    :param curr_node: The current node whose parents are analyzed.
    :return: List of parent nodes to the current node, and a count of add nodes detected
        among the parents.
    """
    parent_nodes = []
    add_node_present = 0
    for input_node_name in curr_node.input_node_names:
        node = net.nodes[input_node_name]
        # If the node is a pooling op, recurse up the network; if it is a conv or add, append it
        # to the parent list; otherwise the node is invalid, so append None
        if isinstance(node.ir.operation, (MaxPool2DOp, AvgPool2DOp)):
            updated_parent_nodes, updated_add_nodes = _get_parent_nodes(net, node)
            parent_nodes.extend(updated_parent_nodes)
            add_node_present += updated_add_nodes
        elif isinstance(node.ir.operation, ConvAddActivationOp):
            parent_nodes.append(node)
        elif isinstance(node.ir.operation, AddActivationOp):
            add_node_present += 1
            parent_nodes.append(node)
        else:
            parent_nodes.append(None)
    return parent_nodes, add_node_present

def get_pairing_lists(net: AwesomeNet) -> tuple[list[PairingSet], list[PairingSet], list[ResidualAddFamily]]:
    """
    Given an AwesomeNet, generate valid parent-children pairings that fit various patterns.
    The patterns are:
        1. Single-parent conv nodes with exclusively conv children.
        2. Concat nodes with exclusively conv parents and children.
        3. Residual add connections with exclusively conv parents and children.

    :param net: AwesomeNet to iterate through.
    :return: Tuple of lists corresponding to each pattern. Each element of a list contains a
        pairing set of parents and children for a pattern, plus a scaling status for patterns
        1 & 2 or a visited flag for pattern 3.
    """
    conv_node_pairing_list: list[PairingSet] = []
    concat_node_pairing_list: list[PairingSet] = []
    add_node_family_list: list[ResidualAddFamily] = []

    # Topological sort to ensure the pairing pass is done in operation execution order
    net.topological_sort()

    for node_name in net.execution_order:
        node = net.nodes[node_name]
        if node_is_awesomenet(node):
            pairing_tuple = get_pairing_lists(node.ir)
            conv_node_pairing_list.extend(pairing_tuple[0])
            concat_node_pairing_list.extend(pairing_tuple[1])
            add_node_family_list.extend(pairing_tuple[2])
            continue
        assert node_is_sima_ir(node)

        # If the node is an add op, find its immediate children and parents (that are not pooling).
        # If a non-conv node is detected in the pairing set, add None to the list. Later, when the
        # residual sets are formed and a None is detected, the residual set is discarded from scaling.
        if isinstance(node.ir.operation, AddActivationOp):
            parent_nodes, _ = _get_parent_nodes(net, node)
            children_nodes, _ = _get_child_nodes(net, node.name)
            # Check for other branches of children conv nodes from the parent convs:
            for parent in parent_nodes:
                if parent is not None and isinstance(parent.ir.operation, ConvAddActivationOp):
                    # Will contain Nones if there are invalid children
                    branch_children, _ = _get_child_nodes(net, parent.name)
                    # Add all other children of the conv node to the list.
                    # Skip add-nodes to avoid infinite recursion.
                    for b_node in branch_children:
                        # b_node may be None, so we need to check that first
                        if b_node is None or not isinstance(b_node.ir.operation, AddActivationOp):
                            children_nodes.append(b_node)
            add_family = ResidualAddFamily(node, parent_nodes, children_nodes)
            add_node_family_list.append(add_family)

        # If the node is a concatenate op, look for its next immediate conv children and conv parents.
        # If a non-conv node is detected in the pairing set, the entire pairing set is discarded.
        elif isinstance(node.ir.operation, TupleConcatenateOp):
            # If the concatenation axis is not the channel axis, don't create a pairing set
            if node.ir.attrs.concat_attrs.axis != 3:
                continue
            children_conv_nodes, num_add_nodes_children = _get_child_nodes(net, node_name)
            parent_conv_nodes, num_add_nodes_parents = _get_parent_nodes(net, node)
            add_node_present = (num_add_nodes_children > 0) or (num_add_nodes_parents > 0)
            if None in parent_conv_nodes or None in children_conv_nodes or add_node_present \
                    or len(children_conv_nodes) == 0 or len(parent_conv_nodes) == 0:
                continue
            pairing = PairingSet(parent_conv_nodes, children_conv_nodes, ScaleStatus.UNSCALED)
            concat_node_pairing_list.append(pairing)

        # If the node is a conv op, look for a potential pair. If a non-conv node is detected
        # in the pairing set, the entire pairing set is discarded.
        elif isinstance(node.ir.operation, ConvAddActivationOp):
            parent_node = node
            children_conv_nodes, add_node_present = _get_child_nodes(net, node_name)
            if None in children_conv_nodes or add_node_present > 0 or len(children_conv_nodes) == 0:
                continue
            pairing = PairingSet([parent_node], children_conv_nodes, ScaleStatus.UNSCALED)
            conv_node_pairing_list.append(pairing)

    return conv_node_pairing_list, concat_node_pairing_list, add_node_family_list

def get_pairings(net: AwesomeNet) -> tuple[list[PairingSet], list[PairingSet], list[PairingSet]]:
    """
    Control function to find parent-children pairings to scale, and do the necessary
    post-processing to generate the final and complete set of pairings.

    The goal of this pairing framework is to assemble self-contained parent-children
    'pairing sets' that need to be scaled, such that each pairing set can be scaled without
    affecting the input activations to the set or the output activations from the set branching
    to other nodes in the network. Note that a node can appear in multiple pairing sets, but is
    always scaled at most once as a parent and at most once as a child. Thus, this framework
    only guarantees that activations going into and coming out of a set are invariant to scaling
    within a set; it does not guarantee that weight changes are self-contained to a single set.

    :param net: Network that is traversed to find node-pairs.
    :return: Tuple of lists of parent-children pairings fitting each of the above listed patterns.
    """
    conv_node_pairing_list, concat_node_pairing_list, add_node_family_list = get_pairing_lists(net)

    # Search for add connections that connect together, and pool their parents and children.
    # Additionally, check that the residual sets are valid to scale.
    residual_node_pairing_list = residual_pair_pass(add_node_family_list)

    return conv_node_pairing_list, concat_node_pairing_list, residual_node_pairing_list

def residual_pair_pass(add_node_family_list: list[ResidualAddFamily]) -> list[PairingSet]:
    """
    Iterate through the net to identify all the residual connection groupings, and identify all
    the parent and children nodes corresponding to each grouping. Each grouping needs to be
    scaled together with the same scale due to the design of the residual connection. A grouping
    is only valid if all its parents and children are convolution nodes (or pooling operators
    that lead to convolutions). Therefore, when a None is detected, the entire residual set is
    assembled and discarded from scaling.

    This algorithm is implemented by first identifying all the add operators in the network and
    their immediate parents and children. It then scans each add-node's connections for further
    add-nodes, and recurses along the add-nodes connected together, pooling the corresponding
    parents and children together. After this, if even one parent/child is not a conv, the entire
    grouping is ineligible for scaling.

    :param add_node_family_list: List of pairing sets of add-nodes, their parents and children,
        and a visited flag for downstream use.
    :return: List of pairing sets, each containing the parents, children, and ScaleStatus of a
        residual grouping.
    """
    final_node_pairing_list: list[PairingSet] = []

    # Iterate through the list of add-nodes, and group valid residual connections together.
    for add_node_idx, add_node_family in enumerate(add_node_family_list):
        if not add_node_family.visited:
            final_parents_list, final_children_list = get_add_connections(add_node_family_list, add_node_idx)

            # Only add to the final pairings if all the parents and children are valid
            if (None not in final_parents_list) and (None not in final_children_list) and \
                    len(final_parents_list) != 0 and len(final_children_list) != 0:
                # Check that all nodes are convs
                num_conv_parents = sum(isinstance(p.ir.operation, ConvAddActivationOp)
                                       for p in final_parents_list)
                num_conv_children = sum(isinstance(c.ir.operation, ConvAddActivationOp)
                                        for c in final_children_list)
                assert num_conv_parents == len(final_parents_list), \
                    'Non-conv nodes detected in final parents list'
                assert num_conv_children == len(final_children_list), \
                    'Non-conv nodes detected in final children list'

                pairing = PairingSet(final_parents_list, final_children_list, ScaleStatus.UNSCALED)
                final_node_pairing_list.append(pairing)

    return final_node_pairing_list

def get_new_add_index(add_node_family_list: list[ResidualAddFamily], add_node_to_find: AwesomeNode) -> int:
    """
    Given a list of pairing sets containing add-nodes, identify the index of a given add-node.

    :param add_node_family_list: List of pairing sets, each containing a distinct add-node and
        its immediate parents and children.
    :param add_node_to_find: Desired add-node.
    :return: Index of the desired add-node in the list.
    """
    add_node_list = [add_node_family.add_node for add_node_family in add_node_family_list]
    return add_node_list.index(add_node_to_find)

def get_add_connections(add_node_family_list: list[ResidualAddFamily], add_node_idx: int) \
        -> tuple[list[AwesomeNode], list[AwesomeNode]]:
    """
    For a given add-node, scan its parents, adding all convs to the list. If an add is detected,
    recurse to that add-node. If neither is detected, the grouping is invalid, so 'None' is
    appended to the list. Do this for the children of the add-node as well.

    :param add_node_family_list: List of pairing sets, each containing a distinct add-node and
        its immediate parents and children.
    :param add_node_idx: Index of the add-node we're currently scanning.
    :return: Lists of parents and children for a residual grouping.
    """
    add_node_family_list[add_node_idx].visited = True
    final_parents = []
    final_children = []

    # Iterate through parents. Append to parents if conv, recurse if add,
    # append 'None' if the node is an invalid operation.
    for parent_node in add_node_family_list[add_node_idx].parents:
        if parent_node is None:
            final_parents.append(None)
        elif isinstance(parent_node.ir.operation, ConvAddActivationOp):
            final_parents.append(parent_node)
        elif isinstance(parent_node.ir.operation, AddActivationOp):
            parent_add_index = get_new_add_index(add_node_family_list, parent_node)
            if not add_node_family_list[parent_add_index].visited:
                sub_parents, sub_children = get_add_connections(add_node_family_list, parent_add_index)
                final_parents = final_parents + sub_parents
                final_children = final_children + sub_children
        else:
            raise sima_logger.UserFacingException(
                f"Unexpected node {parent_node.name} encountered in residual pairing search.")

    # Iterate through children. Append to children if conv, recurse if add,
    # append 'None' if the node is an invalid operation.
    for children_node in add_node_family_list[add_node_idx].children:
        if children_node is None:
            final_children.append(None)
        elif isinstance(children_node.ir.operation, ConvAddActivationOp):
            final_children.append(children_node)
        elif isinstance(children_node.ir.operation, AddActivationOp):
            child_add_index = get_new_add_index(add_node_family_list, children_node)
            if not add_node_family_list[child_add_index].visited:
                sub_parents, sub_children = get_add_connections(add_node_family_list, child_add_index)
                final_parents = final_parents + sub_parents
                final_children = final_children + sub_children
        else:
            raise sima_logger.UserFacingException(
                f"Unexpected node {children_node.name} encountered in residual pairing search.")

    return final_parents, final_children

def get_equalization_scale(parent_node: AwesomeNode) -> tuple[np.ndarray, float, float]:
    """
    Compute a channel-wise scale given the node pair's parent node weights and output activation
    extrema. Based on the Same, Same But Different paper:
    https://proceedings.mlr.press/v97/meller19a/meller19a.pdf.

    :param parent_node: The first node in the node-pair.
    :return: Channel-wise scales, maximum activation value, and new max after scaling.
    """
    act_mins_channel, act_maxes_channel = parent_node.ir.calib_attrs.observer.min_max()
    act_magnitudes = np.maximum(np.abs(act_mins_channel), np.abs(act_maxes_channel))
    max_val = np.max(np.abs(act_maxes_channel))

    # Make sure the scale never falls below 1.
    act_scales_og = np.maximum(max_val / (act_magnitudes + EPSILON), EQUALIZATION_MINVAL)
    # The original paper picks the minimum of the weight scale and the activation scale.
    # As we have per-channel quantization for weights, we don't need to consider weight scales.
    scales_og = np.minimum(act_scales_og, EQUALIZATION_MAXVAL)

    # Get rescaled maxes
    scaled_max = np.max(act_maxes_channel * scales_og)
    return scales_og, max_val, scaled_max

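# --- Illustrative sketch (not part of the transform) --------------------------------------------
# A self-contained numeric walk-through of the scale computation above, using made-up per-channel
# activation extrema in place of a real calibration observer. Only EPSILON, EQUALIZATION_MINVAL,
# and EQUALIZATION_MAXVAL come from this module; all values below are hypothetical.
def _example_equalization_scale() -> tuple[np.ndarray, float, float]:
    act_mins_channel = np.array([-0.5, -2.0, -8.0, -0.1])
    act_maxes_channel = np.array([0.4, 4.0, 8.0, 0.2])
    act_magnitudes = np.maximum(np.abs(act_mins_channel), np.abs(act_maxes_channel))  # [0.5, 4, 8, 0.2]
    max_val = np.max(np.abs(act_maxes_channel))                                       # 8.0
    # Weak channels are boosted toward the strongest channel, clamped to [1, EQUALIZATION_MAXVAL]
    scales = np.maximum(max_val / (act_magnitudes + EPSILON), EQUALIZATION_MINVAL)
    scales = np.minimum(scales, EQUALIZATION_MAXVAL)                                  # [16, 2, 1, 20]
    scaled_max = np.max(act_maxes_channel * scales)                                   # 8.0
    return scales, max_val, scaled_max
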
def parent_node_update(parent_node: AwesomeNode, scales: np.ndarray):
    """
    Update a parent node's weights and bias given its scales.

    :param parent_node: AwesomeNode to update.
    :param scales: Numpy array of scales to apply along the channels of the parent node's
        weights and bias.
    :return: None.
    """
    # Reshape scales to match the (G, O) axes of the parent node's weights.
    # Implicitly broadcast in the other dimensions. Rescale the parent node.
    scale_shape = parent_node.ir.attrs.weights_attrs.data.shape[-2:]
    parent_node.ir.attrs.weights_attrs.data *= scales.reshape(scale_shape)
    if parent_node.ir.attrs.bias_attrs:
        parent_node.ir.attrs.bias_attrs.data *= scales.flatten()

def child_node_update(child_node: AwesomeNode, scales: np.ndarray):
    """
    Update a child node's weights given its scales.

    :param child_node: AwesomeNode to update.
    :param scales: Numpy array of scales to apply along the input channels of the child node's
        weights.
    :return: None.
    """
    conv_attrs = child_node.ir.attrs.conv_attrs
    groups = conv_attrs.groups
    input_channels = conv_attrs.weight_shape[-3]
    # Reshape the scales to match the weight tensor's (I, G) channels.
    scales = scales.reshape((groups, input_channels)).transpose()
    # Add the weight tensor's O channel so that it can be broadcast to the weight tensor.
    scales = np.expand_dims(scales, axis=2)
    child_node.ir.attrs.weights_attrs.data /= scales

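# --- Illustrative sketch (not part of the transform) --------------------------------------------
# Why parent and child updates come in pairs: multiplying a parent's output channels by `s` and
# dividing the child's matching input channels by `s` leaves the composed computation unchanged
# (positive scales also commute with ReLU-style activations). The toy below treats both layers as
# 1x1 convolutions, i.e. plain matrix multiplies; the names and shapes are hypothetical.
def _example_scale_invariance() -> None:
    rng = np.random.default_rng(0)
    x = rng.normal(size=(8,))             # input activations, 8 channels
    w_parent = rng.normal(size=(4, 8))    # parent conv: 8 -> 4 channels
    w_child = rng.normal(size=(3, 4))     # child conv: 4 -> 3 channels
    s = np.array([16.0, 2.0, 1.0, 20.0])  # per-channel scales on the parent's outputs
    y_reference = w_child @ (w_parent @ x)
    y_rescaled = (w_child / s) @ ((w_parent * s[:, None]) @ x)
    assert np.allclose(y_reference, y_rescaled)
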
def scale_concat_children(concat_node_pairing_list: list[PairingSet],
                          parent_scale_dict: dict[NodeName, np.ndarray]):
    """
    Iterate through the list of identified concat pairing sets. For whichever sets are marked as
    'to-scale', concatenate the corresponding parent node scales and rescale the children node
    weights. Mark the pairing's status as 'done' once the children are scaled.

    :param concat_node_pairing_list: List of pairing sets of AwesomeNode parents and children
        that are connected by a concat node.
    :param parent_scale_dict: Dictionary of parent node names and scales.
    :return: None.
    """
    for pairing in concat_node_pairing_list:
        if pairing.status == ScaleStatus.TOSCALE:
            # Concatenate the scales of the parent nodes
            scales = np.concatenate([parent_scale_dict[parent_node.name]
                                     for parent_node in pairing.parents])
            for child_node in pairing.children:
                child_node_update(child_node, scales)
            pairing.status = ScaleStatus.DONE  # Mark as done

def scale_concat_parents(concat_node_pairing: PairingSet) -> dict[NodeName, np.ndarray]:
    """
    Iterate through a set of parent nodes. First find the original scales, then rescale so that
    all the max values are the same. Create a dictionary of the parent node names and their
    corresponding scales to scale the children nodes.

    :param concat_node_pairing: Pairing set for a given concat node, containing a list of parents,
        children, and status. We only use the parents list here.
    :return: Dictionary of parent node names and corresponding scales as numpy arrays.
    """
    parent_scales = {}        # Dict of parent node name: scale array
    parent_scaled_maxes = {}  # Dict of parent node name: scaled max float
    largest_max_val = 0

    # Get scale and max values
    for parent_node in concat_node_pairing.parents:
        parent_scales[parent_node.name], max_val, scaled_max = get_equalization_scale(parent_node)
        parent_scaled_maxes[parent_node.name] = scaled_max
        if max_val > largest_max_val:
            largest_max_val = max_val

    # Modify the scale by a multiplicative factor such that the maximum activation is the same
    # across all concatenated channels
    for parent_node in concat_node_pairing.parents:
        parent_node_name = parent_node.name
        old_parent_scale = parent_scales[parent_node_name]
        parent_scales[parent_node_name] = old_parent_scale * largest_max_val \
            / (parent_scaled_maxes[parent_node_name] + EPSILON)
        parent_node_update(parent_node, parent_scales[parent_node_name])

    return parent_scales

def _is_subset(sub_list: list[AwesomeNode], super_list: list[AwesomeNode]) -> bool:
    """
    Return whether one list is a subset of another list.

    :param sub_list: Initial set of elements.
    :param super_list: Potential superset.
    :return: True if super_list is a superset of sub_list.
    """
    return all(node in super_list for node in sub_list)

def find_largest_parent_set(concat_node_pairing_list: list[PairingSet],
                            parents_list: list[AwesomeNode], idx_parents_list: int) -> int:
    """
    Iterate through the list of identified concat node pairing sets, and find the largest set of
    parents containing an initial subset of parent nodes. Update any concat set's status to
    'to-scale' if the original set of parents is present. This function operates on the key
    assumption that all parent sets are subsets of larger parent sets.

    :param concat_node_pairing_list: List of pairing sets of AwesomeNode parents and children
        that are connected by a concat node.
    :param parents_list: List of the initial set of parents for which we are trying to find
        supersets.
    :param idx_parents_list: Index of parents_list in the full concat_node_pairing_list.
    :return: Index of the largest set of parents.
    """
    max_parent_set_size = len(parents_list)
    index_max_parents = idx_parents_list
    for i, pairing in enumerate(concat_node_pairing_list):
        if _is_subset(parents_list, pairing.parents):
            pairing.status = ScaleStatus.TOSCALE
            num_pairing_parents = len(pairing.parents)
            if num_pairing_parents > max_parent_set_size:
                max_parent_set_size = num_pairing_parents
                index_max_parents = i
    return index_max_parents

def find_residual_scale(parent_nodes: list[AwesomeNode]) -> np.ndarray:
    """
    Iterate through the list of parent nodes and find the minimum of the normalized scales
    across all parents.

    :param parent_nodes: List of AwesomeNode parents.
    :return: Residual scales.
    """
    residual_scales = None
    # Compute the min scale so that none of the activation maxes are changed
    for parent_node in parent_nodes:
        scales, _, _ = get_equalization_scale(parent_node)
        if residual_scales is None:
            residual_scales = scales.copy()
        else:
            norm_residual_scales = residual_scales / np.min(residual_scales)
            norm_node_scales = scales / np.min(scales)
            residual_scales = np.maximum(np.minimum(norm_residual_scales, norm_node_scales),
                                         EQUALIZATION_MINVAL)
    return residual_scales

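# --- Illustrative sketch (not part of the transform) --------------------------------------------
# A small numeric example of the reduction above, with two hypothetical parents' equalization
# scales. Normalizing by each parent's minimum keeps every entry >= 1, and the element-wise
# minimum picks the most conservative boost shared by all parents of the residual connection.
def _example_residual_scale() -> np.ndarray:
    scales_a = np.array([16.0, 2.0, 1.0, 20.0])
    scales_b = np.array([4.0, 8.0, 2.0, 2.0])
    norm_a = scales_a / np.min(scales_a)  # [16, 2, 1, 20]
    norm_b = scales_b / np.min(scales_b)  # [2, 4, 1, 1]
    return np.maximum(np.minimum(norm_a, norm_b), EQUALIZATION_MINVAL)  # [2, 2, 1, 1]
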
def pairings_update_pass(pairings_lists: tuple[list[PairingSet], list[PairingSet], list[PairingSet]]):
    """
    Iterate through the lists of identified single-conv-parent pairings, concat pairings, and
    residual-add pairings, and scale them channel-wise based on channel equalization.

    For conv_node_pairing_list, there can only be one parent, but the parent may have multiple
    children. The scale is found from the parent node, and is used to inversely scale the
    children nodes.

    For concat_node_pairing_list, iterate through the list of identified concat node pairing
    sets. For a given concat, the algorithm finds the scales for all its parents. It then
    rescales the scales to have the same max value. It then scales the parent nodes' weights,
    concats the scales together, and inversely scales the children nodes' weights. The algorithm
    makes the key assumption that for a parent set of a concat node, any other concat node that
    shares one of the parents will share all the parents (i.e., concats are structured in a
    hierarchical manner, where parent sets are subsets of larger parent sets). This assumption
    is also made by the Same, Same But Different paper's channel equalization implementation:
    https://github.com/icml2019/equalization.

    For residual_node_pairing_list, iterate through the list of identified residual node-pairs
    and scale them channel-wise. Do this by first finding the minimum of the normalized scales
    across all parents. Then scale all parents and children by the same scale.

    :param pairings_lists: A tuple containing three lists of node pairings to scale:
        conv_node_pairing_list, concat_node_pairing_list, residual_node_pairing_list.
    :return: Updates model weights in-place, returns nothing.
    """
    conv_node_pairing_list, concat_node_pairing_list, residual_node_pairing_list = pairings_lists

    # Single conv parent node pairings update pass
    for node_pairing in conv_node_pairing_list:
        parent_nodes = node_pairing.parents
        assert len(parent_nodes) == 1, f"Expect only one parent node but got {len(parent_nodes)}"
        children_nodes = node_pairing.children
        parent_node = parent_nodes[0]
        scales, _, _ = get_equalization_scale(parent_node)
        parent_node_update(parent_node, scales)
        for child_node in children_nodes:
            child_node_update(child_node, scales)
        node_pairing.status = ScaleStatus.DONE

    # Concat node pairings update pass. We scale all concats for a given parent at once.
    for i, concat_node_pairing in enumerate(concat_node_pairing_list):
        # Make sure no 'to-scale's are present outside the scale functions.
        # Status should be 'unscaled' or 'done'.
        assert concat_node_pairing.status != ScaleStatus.TOSCALE
        # If the pairing is not marked as done or to-scale, find the largest set of parents and
        # scale accordingly
        if concat_node_pairing.status == ScaleStatus.UNSCALED:
            parents_list = concat_node_pairing.parents
            # Find the largest set of parents and mark them as to-scale, so we can compute the
            # scales once
            largest_parent_set_idx = find_largest_parent_set(concat_node_pairing_list, parents_list, i)
            # Get parent scales, rescale, and update parent weights.
            # Returns a dict of the parent node names and scales.
            parent_scale_dict = scale_concat_parents(concat_node_pairing_list[largest_parent_set_idx])
            # Scale all children nodes accordingly, and update the status to 'done'
            scale_concat_children(concat_node_pairing_list, parent_scale_dict)

    # Residual add node pairings update pass
    for residual_node_pairing in residual_node_pairing_list:
        parent_nodes = residual_node_pairing.parents
        children_nodes = residual_node_pairing.children
        residual_scales = find_residual_scale(parent_nodes)
        # Scale parents by the computed residual scale
        for parent_node in parent_nodes:
            parent_node_update(parent_node, residual_scales)
        # Scale children by the computed residual scale
        for child_node in children_nodes:
            child_node_update(child_node, residual_scales)
        residual_node_pairing.status = ScaleStatus.DONE

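# --- Illustrative sketch (not part of the transform) --------------------------------------------
# A hypothetical end-to-end driver showing how the pieces above are expected to fit together:
# discover the pairing sets on an already-calibrated AwesomeNet, then rescale weights and biases
# in place. This helper is illustrative only and is not part of this module's API.
def _example_run_channel_equalization(net: AwesomeNet) -> None:
    pairings = get_pairings(net)    # (conv pairings, concat pairings, residual pairings)
    pairings_update_pass(pairings)  # updates node weights and biases in-place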