#########################################################
# Copyright (C) 2024 SiMa Technologies, Inc.
#
# This material is SiMa proprietary and confidential.
#
# This material may not be copied or distributed without
# the express prior written permission of SiMa.
#
# All rights reserved.
#########################################################
# Code owner: Christopher Rodrigues
#########################################################
"""
This module defines `analyze_needs`, an analysis that determines
how every node's output will be requantized. For every node in
the graph (exceptions are listed below), it records what
requantization was performed for all uses of the node's output,
no matter where those requantizations occur in the graph.
Use without requantization is also recorded.
Additionally, it records the same information about the
inputs of a subnet.
The analysis proceeds as a backward pass over the graph.
Each time it finds a requantization node, it records how
the node's input tensor was requantized. That information
is propagated to the node that produced the tensor.
Requantize, Tuple, TupleGetItem, and Placeholder nodes
are not recorded. Instead, their output information is
transformed and propagated to their input.
"""
from typing import List, Tuple, Dict, Optional
import dataclasses
from afe.ir.defines import DataValue, TupleValue, TensorValue, map_data_value, NodeName
from afe.ir.net import AwesomeNet
from afe.ir.node import AwesomeNode, node_is_placeholder
from afe.ir.sima_ir import SiMaIR
import afe.ir.operations as afe_op
import afe.ir.attributes as afe_attr
from afe.ir.tensor_type import TensorType
from afe.ir.transform.requantization_hoisting.defines import Need, DataNeeds
class _RequantizationAnalysisResult:
    """
    Result of requantization analysis.

    During analysis, an instance holds the needs that have been computed for
    each node and the placeholders of each subnetwork that has been analyzed.
    When analysis has completed, it holds information for all nodes and all
    subnetworks.
    """
    # Needs recorded for each node's output, including nodes inside subnetworks
    node_needs: Dict[NodeName, DataNeeds]
    # Needs recorded for the inputs of subnetworks. Each list is ordered the
    # same way as the subnetwork's input node names.
    subnet_input_needs: Dict[NodeName, List[DataNeeds]]

    def __init__(self):
        self.node_needs = {}
        self.subnet_input_needs = {}

    def get_node_needs(self, node: NodeName) -> Optional[DataNeeds]:
        """
        Get the needs of a node's output.

        Return None if nothing was found, which means that the analysis did
        not find any use of the node's output.
        """
        return self.node_needs.get(node, None)

    def find_node_needs(self, node: NodeName) -> DataNeeds:
        """
        Get the needs of a node's output. Raises KeyError if none were recorded.
        """
        return self.node_needs[node]

    def update_need(self, node: NodeName, need: DataNeeds):
        """
        Record the need on the node's output.

        It is combined with any previously recorded needs for the node.
        """
        previous = self.node_needs.get(node, None)
        self.node_needs[node] = need if previous is None else join_needs(previous, need)

    def update_subnet_input_needs(self, node: NodeName, needs: List[DataNeeds]):
        """
        Record the needs on the subnet's inputs. Each subnet may be recorded
        only once.
        """
        assert node not in self.subnet_input_needs
        self.subnet_input_needs[node] = needs
def _not_requantized_need(ty: DataValue[TensorType]) -> DataNeeds:
    """
    Create a DataNeeds describing a use of the given type
    without any requantization.

    :param ty: Type being used
    :return: Needs
    """
    def unrequantized_use(_: TensorType) -> List[Need]:
        # A single use without requantization is represented by [None]
        return [None]

    return map_data_value(unrequantized_use, ty)
def join_need_lists(xs: List[Need], ys: List[Need]) -> List[Need]:
    """
    Concatenate two lists of Need and remove duplicate items.
    Each input list is assumed not to contain duplicate items.

    :param xs: First list of needs. Its items keep their relative order.
    :param ys: Second list of needs. Items not already present are appended
        in order.
    :return: New list holding the union of xs and ys. Neither input is mutated.
    """
    ret = list(xs)
    # Insert items from ys, discarding duplicates. Membership is tested with
    # == by linear scan rather than with a set because Need values are not
    # required to be hashable.
    for y in ys:
        if not any(y == x for x in ret):
            ret.append(y)
    return ret
[docs]
def join_needs(x: DataNeeds, y: DataNeeds) -> DataNeeds:
match (x, y):
case (TensorValue(x1), TensorValue(y1)):
return TensorValue(join_need_lists(x1, y1))
case (TupleValue(xs), TupleValue(ys)):
assert len(xs) == len(ys)
return TupleValue([join_needs(xi, yi) for xi, yi in zip(xs, ys)])
case _:
raise ValueError("DataValues have different shapes")
def update_need_dict(d: Dict[NodeName, DataNeeds], k: NodeName, v: DataNeeds):
    """
    Record needs v for node k in dictionary d.

    The value is combined with any needs previously recorded under k,
    mirroring _RequantizationAnalysisResult.update_need.

    :param d: Dictionary of needs, updated in place
    :param k: Node whose needs are recorded
    :param v: Needs to record
    """
    old_v = d.get(k, None)
    if old_v is not None:
        v = join_needs(old_v, v)
    d[k] = v
def propagate_node_need(results: _RequantizationAnalysisResult, node: AwesomeNode):
    """
    Determine the needs of an AwesomeNode.

    The needs of the node's output must already be recorded in results.
    The needs of the node's inputs are computed and recorded in results.

    :param results: Results of analysis. Holds previously collected results, which may be read.
        Will be updated with results of analyzing node.
    :param node: Node to analyze
    :raise ValueError: If the node's IR is neither a SiMaIR nor an AwesomeNet
    """
    output_needs = results.find_node_needs(node.name)
    if isinstance(node.ir, SiMaIR):
        input_needs = sima_ir_needs(node.ir, output_needs)
    elif isinstance(node.ir, AwesomeNet):
        input_needs = net_needs(results, node.ir, output_needs)
        results.update_subnet_input_needs(node.name, input_needs)
    else:
        raise ValueError("Unrecognized IR type")

    # Propagate each input's needs to the node that produced that input
    assert len(input_needs) == len(node.input_node_names)
    for input_node_name, n in zip(node.input_node_names, input_needs):
        results.update_need(input_node_name, n)
def sima_ir_needs(node: SiMaIR, output_need: DataNeeds) -> List[DataNeeds]:
    """
    Determine the needs of a SiMaIR's input, given the needs of
    its output.

    :param node: Node to get the needs of.
    :param output_need: Need of the output node.
    :return: Needs of the node's inputs. The list has the same order as
        the inputs in the node's type.
    """
    # Use quantized attributes when available
    attrs = node.quant_attrs if node.quant_attrs is not None else node.attrs
    if isinstance(node.operation, afe_op.TupleGetItemOp):
        # Propagate the output need to the selected tuple item.
        assert isinstance(attrs, afe_attr.TupleGetItemAttrs)
        # Tuple is assumed not to contain tuples.
        # Build a distinct empty TensorValue for each position: list
        # multiplication would alias a single TensorValue (and its list)
        # across all positions.
        needs_list: List[DataNeeds] = [TensorValue([]) for _ in attrs.input_types]
        needs_list[attrs.index] = output_need
        return [TupleValue(needs_list)]
    elif isinstance(node.operation, afe_op.TupleOp):
        # Propagate the output need to all inputs
        assert isinstance(attrs, afe_attr.TupleAttrs)
        assert isinstance(output_need, TupleValue)
        assert len(output_need.elements) == len(attrs.input_types)
        return output_need.elements
    elif isinstance(node.operation, afe_op.RequantizeOp):
        assert isinstance(attrs, afe_attr.RequantizeQuantAttrs)
        assert isinstance(output_need, TensorValue)
        # The output of requantize should not be passed to another requantize
        assert all(v is None for v in output_need.value)
        # The node's input is used with this node's requantization
        return [TensorValue([attrs.requant])]
    elif isinstance(node.operation, afe_op.PlaceholderOp):
        # Placeholders are handled by the caller (net_needs), never here
        raise TypeError("Unexpected placeholder")

    # else, this operator is not handled specially. Each of the node's inputs
    # is used without requantization.
    return [_not_requantized_need(t) for t in node.get_type().inputs.values()]
def net_needs(results: _RequantizationAnalysisResult, net: AwesomeNet, output_needs: DataNeeds) -> List[DataNeeds]:
    """
    Determine the needs of an AwesomeNet and of all nodes that it contains.

    :param results: Results of analysis. Holds previously collected results, which may be read.
        Will be updated with results of analyzing net.
    :param net: Network to analyze
    :param output_needs: Needs of the net's output
    :return: The needs of the net's inputs. The list has the same order as
        net.input_node_names.
    """
    # Initialize needs of the output node
    results.update_need(net.output_node_name, output_needs)

    # Traverse the network in reverse topological order, so that a node's
    # output needs are fully recorded before the node itself is analyzed
    traversal_order = list(reversed(net.execution_order))
    assert traversal_order[0] == net.output_node_name
    for node_name in traversal_order:
        node = net.nodes[node_name]
        if node_is_placeholder(node):
            # Placeholder's input comes from outside this net
            continue
        propagate_node_need(results, node)

    # Collect needs of the net's inputs
    input_needs = []
    for placeholder_name in net.input_node_names:
        n = results.get_node_needs(placeholder_name)
        if n is None:
            # Placeholder is not used; record an empty need list for its tensor
            assert isinstance(net.nodes[placeholder_name].get_type().output, TensorValue)
            n = TensorValue([])
        input_needs.append(n)
    return input_needs
def analyze_needs(net: AwesomeNet) -> Tuple[Dict[NodeName, DataNeeds], Dict[NodeName, List[DataNeeds]]]:
    """
    Determine the needs of all nodes in a top-level AwesomeNet.

    :param net: Network to analyze
    :return: Tuple of (needs on the outputs of all nodes in the network,
        needs on the inputs of each subnetwork, each list ordered like the
        subnetwork's input node names)
    """
    # All tensors in the model's output are used, without requantization
    output_type = net.nodes[net.output_node_name].get_type().output
    output_needs = _not_requantized_need(output_type)

    # Analyze the network
    results = _RequantizationAnalysisResult()
    net_needs(results, net, output_needs)
    return results.node_needs, results.subnet_input_needs