#########################################################
# Copyright (C) 2022 SiMa Technologies, Inc.
#
# This material is SiMa proprietary and confidential.
#
# This material may not be copied or distributed without
# the express prior written permission of SiMa.
#
# All rights reserved.
#########################################################
# Code owner: Ljubomir Papuga
#########################################################
import math
import numpy as np
from collections import defaultdict
from dataclasses import replace
import afe.ir.operation_functions as op_fn
from afe.apis.transform import (
detessellation_transform, pack_transform, slice_transform, tessellation_transform,
unpack_transform
)
from afe.backends import BackendIR
from afe.backends.mla.afe_to_n2a_compiler.defines import PackParameters, TensorTessellateParameters
from afe.ir.attributes import (
DetessellationTransformAttrs, UnpackTransformAttrs, QuantResultTensorType
)
from afe.ir.build_node import create_tuple_get_item_nodes
from afe.ir.defines import (
InputName, NodeName, TensorValue, TupleValue, data_value_elements, get_expected_tensor_value
)
from afe.ir.net import AwesomeNet, rename_mut_awesomenet, Renaming
from afe.ir.node import AwesomeNode, node_is_tuple_get_item
from afe.ir.sima_ir import SiMaIR
from afe.ir.tensor_type import NodeType, TensorType, ScalarType
from sima_utils.logging.sima_logger import sima_log_dbg
def _round_up_to(value: int, multiple: int) -> int:
return ((value + multiple - 1) // multiple) * multiple
def _generate_tessellate_node(
    input_name: InputName, input_node_name: NodeName, tessellate_param: TensorTessellateParameters,
    input_type: TensorType
) -> AwesomeNode:
    """
    Build the AwesomeNode holding a TessellationTransformOp for one MLA input.

    The node converts a single 4D NHWC input into the blocked HWC layout expected by the
    AwesomeNode containing the BackendIR that runs on MLA.

    :param input_name: Name of the MLA input being tessellated; used in the node name.
    :param input_node_name: Name of the node producing the value to tessellate.
    :param tessellate_param: Tessellation parameters holding the tile shape of each block and
        the DRAM layout options (align_c16, cblock).
    :param input_type: Type of the value fed to the new tessellate node. The caller must supply
        it, since the type cannot currently be derived from the MLA node.
    :return: AwesomeNode containing the TessellationTransformOp.
    """
    tile_shape = tessellate_param.tile_shape
    is_4d = len(input_type.shape) == 4
    if is_4d:
        # The batch dimension of the tile must be 1; it is dropped from the slice shape.
        assert tile_shape[0] == 1
    slice_shape = tile_shape[1:] if is_4d else tile_shape

    layout = tessellate_param.dram_layout
    tessellate_node_name = f"tessellate_{input_node_name}_{input_name}"
    transform = tessellation_transform(
        slice_shape=slice_shape,
        align_c16=layout.align_c16,
        cblock=layout.cblock
    )
    _, generated_nodes = transform.extract_ir(input_type, input_node_name, tessellate_node_name)
    assert len(generated_nodes) == 1
    sima_log_dbg(f"Inserted tessellate node: {tessellate_node_name}")
    return generated_nodes[0]
def _generate_pack_node(
    input_types: list[TensorType], input_node_names: list[NodeName], pack_node_name: NodeName
) -> AwesomeNode:
    """
    Build the AwesomeNode holding a PackTransformOp, which merges several MLA inputs into one
    buffer.

    :param input_types: Tensor types of the values being packed, in pack order.
    :param input_node_names: Names of the nodes producing those values, in the same order.
    :param pack_node_name: Name to give the new pack node.
    :return: AwesomeNode containing the PackTransformOp.
    """
    transform = pack_transform()
    _, generated_nodes = transform.extract_pack_ir(input_types, input_node_names, pack_node_name)
    assert len(generated_nodes) == 1
    return generated_nodes[0]
[docs]
def insert_pre_mla_segment_nodes(
    net: AwesomeNet, mla_node: AwesomeNode, pack_params: PackParameters
):
    """
    Insert, in place, the nodes that must run before the AwesomeNode containing the BackendIR
    executed on MLA.

    Inputs are processed per section buffer described by ``pack_params``:
    - An input whose tessellation is not done on the MLA (``enable_mla`` is False) gets a
      tessellation node in front of it. Otherwise the original producer is used directly.
    - A section with more than one input gets a pack node that combines them into a single
      buffer. A section with one input uses that (possibly tessellated) node as-is.
    Finally, ``mla_node`` is rewired so that its inputs are the per-section nodes in
    ``pack_params`` order, and its input types are updated accordingly.

    :param net: Top-level AwesomeNet.
    :param mla_node: AwesomeNode containing the MLA sub-graph.
    :param pack_params: Mapping from section name to a list of ``(input_id, tessellate_param)``
        tuples. ``input_id`` is an index into the MLA input list as it was before backend
        compilation. The order of the mapping's items is the order of the MLA arguments after
        backend compilation.
    :return: None. ``net`` is mutated by adding tessellation and, if needed, pack nodes. It is
        left in an invalid state; topological_sort is needed to make it valid again.
    """
    original_input_node_names = mla_node.input_node_names
    input_types = [
        get_expected_tensor_value(net.nodes[name].get_type().output)
        for name in original_input_node_names
    ]

    section_nodes = []
    for section_base_name, pack_input_list in pack_params.items():
        # Resolve the node that feeds each input of this section, tessellating it first when
        # the MLA will not do so.
        member_names = []
        for input_id, tessellate_param in pack_input_list:
            source_name = original_input_node_names[input_id]
            if not tessellate_param.enable_mla:
                tessellate_node = _generate_tessellate_node(
                    mla_node.input_names[input_id], source_name,
                    tessellate_param, input_types[input_id]
                )
                net.nodes.update({tessellate_node.name: tessellate_node})
                source_name = tessellate_node.name
            member_names.append(source_name)

        if len(pack_input_list) <= 1:
            # A lone input is fed to the MLA node directly.
            section_nodes.append(net.nodes[member_names[0]])
            continue

        # Several inputs share this section buffer: pack them together.
        member_types = [
            get_expected_tensor_value(net.nodes[name].get_type().output)
            for name in member_names
        ]
        pack_node = _generate_pack_node(
            member_types, member_names,
            NodeName(f"{mla_node.name}_{section_base_name.replace('.', '_')}")
        )
        net.nodes.update({pack_node.name: pack_node})
        section_nodes.append(pack_node)

    # Rewire the MLA node to the per-section inputs and refresh its input types.
    mla_node.input_node_names = [node.name for node in section_nodes]
    mla_node.ir.pack_parameters = pack_params
    new_input_types = {
        node.name: TensorValue(get_expected_tensor_value(node.get_type().output))
        for node in section_nodes
    }
    mla_node.ir.type = NodeType(new_input_types, mla_node.ir.get_type().output)
def _update_node_calib_attrs_from_backend_quant_attrs(node: AwesomeNode,
                                                      tensor_types: list[TensorType]):
    """
    Set ``node.ir.calib_attrs.quant`` from a list of output tensor types.

    Each type is wrapped in a QuantResultTensorType with ``quant`` and ``requant_method`` set
    to None. A single type is stored as a TensorValue. Several types are stored as a
    TupleValue of TensorValues.

    :param node: Node whose calibration quant result is overwritten.
    :param tensor_types: Output tensor types of the node. Must be non-empty.
    :return: None.
    """
    wrapped = [
        TensorValue(QuantResultTensorType(type=tensor_type, quant=None, requant_method=None))
        for tensor_type in tensor_types
    ]
    node.ir.calib_attrs.quant = TupleValue(wrapped) if len(wrapped) > 1 else wrapped[0]
def _generate_unpack_node(
    batch_size: int, input_size: int, output_types: list[TensorType], input_node_name: NodeName,
    unpack_node_name: NodeName
) -> AwesomeNode:
    """
    Generate AwesomeNode containing the UnpackTransformOp. Needed to unpack multiple outputs
    of the MLA node into multiple buffers.

    The node's input is described as an int8 buffer of shape ``(batch_size, input_size)``.
    Its calibration quant result is set from the tensor types recorded on the generated
    UnpackTransformAttrs.

    :param batch_size: Batch size.
    :param input_size: Number of bytes to be unpacked for each batch.
    :param output_types: Tensor types of the unpacked outputs.
    :param input_node_name: Input node name of the unpack node.
    :param unpack_node_name: Name of the unpack node.
    :return: AwesomeNode containing the UnpackTransformOp.
    """
    # The packed buffer is viewed as raw bytes: one row of input_size bytes per batch element.
    packed_input_type = TensorType(ScalarType.int8, (batch_size, input_size))
    _, unpack_nodes = unpack_transform(tensor_types=output_types).extract_unpack_ir(
        packed_input_type, input_node_name, unpack_node_name
    )
    assert len(unpack_nodes) == 1
    unpack_node = unpack_nodes[0]

    unpack_attrs = unpack_node.ir.attrs
    assert isinstance(unpack_attrs, UnpackTransformAttrs)
    # Use the types stored on the generated attrs so calib_attrs matches the op's own view.
    _update_node_calib_attrs_from_backend_quant_attrs(unpack_node, unpack_attrs.tensor_types)
    sima_log_dbg(f"Inserted unpack node: {unpack_node.name}")
    return unpack_node
def _generate_detessellate_node(
    input_node: AwesomeNode, detessellate_param: TensorTessellateParameters, output_type: TensorType
) -> AwesomeNode:
    """
    Generate the AwesomeNode containing the DetessellationTransformOp.

    The node reads a tessellated int8 buffer and produces a tensor of ``output_type``. Its
    calibration quant result is a TensorValue wrapping the op's frame type with no quant
    parameters.

    :param input_node: AwesomeNode whose output should be detessellated. When the MLA node has
        a single output this is the MLA node itself (BackendIR). When the MLA node has
        multiple outputs it is a TupleGetItem node (SiMaIR).
    :param detessellate_param: Detessellation parameters of this output (tile shape and DRAM
        layout).
    :param output_type: TensorType that the detessellation node should produce.
    :return: AwesomeNode containing the DetessellationTransformOp.
    """
    detessellate_node_name = NodeName(f"detessellate_{input_node.name}")

    # Calculate block sizes. For a 4D frame the tile's batch dimension must be 1 and is dropped.
    if len(output_type.shape) == 4:
        assert detessellate_param.tile_shape[0] == 1
        slice_shape = detessellate_param.tile_shape[1:]
    else:
        slice_shape = detessellate_param.tile_shape

    align_c16 = detessellate_param.dram_layout.align_c16
    # Shape of the tessellated int8 buffer consumed by the detessellation node.
    input_shape = op_fn.calculate_tessellated_tensor_shape(output_type, slice_shape, align_c16)
    _, detessellate_nodes = detessellation_transform(
        slice_shape=slice_shape,
        align_c16=align_c16,
        cblock=detessellate_param.dram_layout.cblock,
        frame_type=output_type
    ).extract_ir(TensorType(ScalarType.int8, input_shape), input_node.name, detessellate_node_name)
    assert len(detessellate_nodes) == 1
    detessellate_node = detessellate_nodes[0]

    detessellate_attrs = detessellate_node.ir.attrs
    assert isinstance(detessellate_attrs, DetessellationTransformAttrs)
    tensor_type = detessellate_attrs.frame_type
    if isinstance(input_node.ir, BackendIR):
        _update_node_calib_attrs_from_backend_quant_attrs(detessellate_node, [tensor_type])
    else:
        assert isinstance(input_node.ir, SiMaIR)
        assert isinstance(input_node.ir.calib_attrs.quant, TensorValue)
        detessellate_node.ir.calib_attrs.quant = TensorValue(
            QuantResultTensorType(tensor_type, None, None)
        )
    sima_log_dbg(f"Inserted detessellate node: {detessellate_node_name}")
    return detessellate_node
def _generate_slice_node(
    input_node: AwesomeNode, output_type: TensorType, mla_output_shape: tuple[int, ...]
) -> AwesomeNode:
    """
    Generate the AwesomeNode containing the SliceTransformOp that strips padding from an MLA
    output.

    The slice keeps, on every axis, the leading ``output_type.shape`` elements of the padded
    tensor of shape ``mla_output_shape``.

    :param input_node: AwesomeNode whose output should be sliced. When the MLA node has a single
        output this is the MLA node itself (BackendIR). When the MLA node has multiple outputs
        it is a TupleGetItem node (SiMaIR).
    :param output_type: TensorType that the slice node should produce (the unpadded type).
    :param mla_output_shape: Shape of the padded tensor produced by the MLA node, whose channel
        dimension is aligned as required by the MLA.
    :return: AwesomeNode containing the SliceTransformOp.
    """
    slice_node_name = f"slice_{input_node.name}"
    desired_shape = output_type.shape
    padded_type = TensorType(output_type.scalar, mla_output_shape)
    sliced_type = TensorType(output_type.scalar, desired_shape)

    # Window [0, desired_shape) on every axis drops the trailing padding.
    transform = slice_transform(begin=[0] * len(desired_shape), end=list(desired_shape))
    _, slice_nodes = transform.extract_ir(padded_type, input_node.name, NodeName(slice_node_name))
    assert len(slice_nodes) == 1
    slice_node = slice_nodes[0]

    if isinstance(input_node.ir, BackendIR):
        _update_node_calib_attrs_from_backend_quant_attrs(slice_node, [sliced_type])
    else:
        assert isinstance(input_node.ir, SiMaIR)
        assert isinstance(input_node.ir.calib_attrs.quant, TensorValue)
        slice_node.ir.calib_attrs.quant = TensorValue(
            QuantResultTensorType(sliced_type, None, None)
        )
    sima_log_dbg(f"Inserted slice node: {slice_node_name}")
    return slice_node
def _generate_output_node(
    mla_output_node: AwesomeNode, detessellate_param: TensorTessellateParameters,
    output_type: TensorType
) -> AwesomeNode:
    """
    Generate an output node of an MLA segment.

    - If ``detessellate_param.enable_mla`` is False, the MLA emits tessellated data and a
      detessellation node is generated.
    - If it is True and the channel-aligned MLA output shape differs from
      ``output_type.shape``, a slice node is generated to remove the padding.
    - Otherwise the data is already in the requested layout and ``mla_output_node`` itself is
      returned.

    :param mla_output_node: An output of the MLA node (the MLA node itself or a TupleGetItem
        node).
    :param detessellate_param: Detessellation parameters of ``mla_output_node``.
    :param output_type: Tensor type the MLA segment output should have.
    :return: A new output node of the MLA segment, or ``mla_output_node`` when no conversion is
        needed.
    """
    if not detessellate_param.enable_mla:
        # The data is tessellated.
        return _generate_detessellate_node(mla_output_node, detessellate_param, output_type)

    elem_size = np.dtype(output_type.scalar.numpy_type()).itemsize
    mla_output_shape = op_fn.get_channel_aligned_shape(
        tensor_shape=output_type.shape, elem_size=elem_size
    )
    # Compare as tuples so that a list/tuple/array representation mismatch between the two
    # shapes is neither mistaken for padding nor raises on array truthiness.
    if tuple(mla_output_shape) == tuple(output_type.shape):
        # The data is already in the right format.
        return mla_output_node
    # The data has padding.
    return _generate_slice_node(mla_output_node, output_type, mla_output_shape)
def _get_mla_segment_output_tgi_nodes(net: AwesomeNet, mla_node: AwesomeNode) -> list[AwesomeNode]:
    """
    Return the TupleGetItem nodes that consume a multi-output MLA node.

    Every node in ``net`` that lists ``mla_node`` among its inputs must be a TupleGetItem node.

    :param net: Top-level AwesomeNet representation of the model.
    :param mla_node: An AwesomeNode containing the BackendIR which is executed on MLA.
    :return: The consuming TupleGetItem nodes, ordered by their tuple index (ties keep the
        iteration order of ``net.nodes``).
    """
    consumers = [
        candidate for candidate in net.nodes.values()
        if mla_node.name in candidate.input_node_names
    ]
    for consumer in consumers:
        assert node_is_tuple_get_item(consumer)
    return sorted(consumers, key=lambda consumer: consumer.ir.attrs.index)
def _update_post_mla_segment_tuple_get_item_nodes(
    unpack_node: AwesomeNode, tgi_nodes: list[AwesomeNode],
    unpack_output_list: list[tuple[int, TensorTessellateParameters]]
):
    """
    Rewire TupleGetItem nodes so that they read from an unpack node instead of the MLA node.

    For the k-th entry ``(mla_output_id, _)`` of ``unpack_output_list``, the TupleGetItem node
    at position ``mla_output_id`` in ``tgi_nodes`` is updated in place:
    - its index becomes k,
    - its input becomes ``unpack_node``,
    - its input types become the unpack node's output tensor types,
    - its calibration quant result becomes the k-th element of the unpack node's quant result.

    :param unpack_node: AwesomeNode containing the UnpackTransformOp, which unpacks the outputs
        of the MLA node into multiple buffers.
    :param tgi_nodes: TupleGetItem nodes collecting the MLA node's outputs, indexed by MLA
        output id (before backend compilation).
    :param unpack_output_list: ``(output_id, detessellation_param)`` pairs in unpack output
        order. The output ids are indices into the output list before backend compilation.
    :return: None.
    """
    unpack_attrs = unpack_node.ir.attrs
    assert isinstance(unpack_attrs, UnpackTransformAttrs)
    shared_input_types = unpack_attrs.tensor_types
    unpacked_quant = unpack_node.ir.calib_attrs.quant
    assert isinstance(unpacked_quant, TupleValue)

    for position, (mla_output_id, _) in enumerate(unpack_output_list):
        assert mla_output_id < len(tgi_nodes)
        target = tgi_nodes[mla_output_id]
        target.ir.calib_attrs.quant = unpacked_quant.elements[position]
        target.ir.attrs.input_types = shared_input_types
        target.ir.attrs.index = position
        target.input_node_names = [unpack_node.name]
[docs]
def insert_post_mla_segment_nodes(
    net: AwesomeNet, mla_node: AwesomeNode, unpack_params: PackParameters,
    uncompiled_nodes: list[NodeName]
) -> None:
    """
    Insert nodes into the AwesomeNet in place after the AwesomeNode containing the
    BackendIR which is to be executed on MLA.

    In case of multiple output sections from the MLA node, new TupleGetItem AwesomeNodes are
    generated to collect the outputs from the MLA node.
    In case of multiple outputs in an output section, the section outputs are unpacked into multiple
    buffers.
    Finally, a detessellate node is inserted if the detessellation of an output is not done on MLA;
    otherwise, a slice node is inserted if there is padding to be removed.
    Existing consumers of the old MLA outputs are redirected to the new output nodes via
    rename_mut_awesomenet, and the MLA node's unpack_parameters and output type are updated.

    :param net: Top-level AwesomeNet.
    :param mla_node: AwesomeNode containing the MLA sub-graph. Its ir must be a BackendIR.
    :param unpack_params: A dictionary containing output detessellation parameters for
        each section buffer from the MLA subgraph. Dictionary keys are the section names. Values in
        the dictionary are tuples of output ids and corresponding tessellation params. The output
        ids are the indices in the MLA output list before the backend compilation. The order of
        dictionary items is the order of MLA outputs after the backend compilation.
    :param uncompiled_nodes: The list of nodes that have not been visited during compilation.
        The list is mutated so that potential TupleGetItem nodes are removed from the list
        while unpack nodes are generated.
    :return: None. Mutates the AwesomeNet by inserting the detessellation nodes and unpack
        node, if needed. The AwesomeNet is in invalid state on the return. The
        topological_sort is needed to return the AwesomeNet to valid state.
    """
    assert isinstance(mla_node.ir, BackendIR)
    # Find the current mla output nodes before inserting new nodes.
    if len(mla_node.ir.graph.outputs) > 1:
        # Multiple outputs: the current consumers are TupleGetItem nodes, sorted by tuple index.
        mla_output_nodes = _get_mla_segment_output_tgi_nodes(net, mla_node)
        # Mark TGI nodes as compiled to avoid traversing them.
        for tgi_node in mla_output_nodes:
            uncompiled_nodes.remove(tgi_node.name)
    else:
        mla_output_nodes = [mla_node]

    # Collect the output shapes and input size of the unpack nodes.
    output_types = data_value_elements(mla_node.get_type().output)
    # The batch size is taken from the first output; each output is asserted to match below.
    batch_size = output_types[0].shape[0]
    # Per section: tensor types of its unpacked outputs, and its total per-batch size.
    unpack_output_types_dict = defaultdict(list)
    unpack_input_sizes_dict = defaultdict(int)
    # One entry per section: the type of the buffer read for that section.
    tgi_output_types = list()
    for section_base_name, unpack_output_list in unpack_params.items():
        for output_id, detessellate_param in unpack_output_list:
            output_type = output_types[output_id]
            if detessellate_param.enable_mla:
                # Detessellated on MLA: same scalar type, channel-aligned (possibly padded) shape.
                elem_size = np.dtype(output_type.scalar.numpy_type()).itemsize
                output_shape = op_fn.get_channel_aligned_shape(
                    tensor_shape=output_type.shape, elem_size=elem_size
                )
                unpack_output_type = TensorType(output_type.scalar, output_shape)
                # Bytes per batch element: product of the non-batch dims times the element size.
                unpack_input_size = math.prod(output_shape[1:], start=elem_size)
            else:
                # Tessellated output: described as an int8 buffer of the tessellated shape.
                tile_shape = detessellate_param.tile_shape
                if len(output_type.shape) == 4:
                    assert tile_shape[0] == 1
                    tile_shape = tile_shape[1:]
                output_shape = op_fn.calculate_tessellated_tensor_shape(
                    output_type, tile_shape, detessellate_param.dram_layout.align_c16
                )
                unpack_output_type = TensorType(ScalarType.int8, output_shape)
                # NOTE(review): assumes the tessellated shape is (batch, bytes_per_batch) - confirm.
                unpack_input_size = output_shape[1]
            assert output_shape[0] == batch_size
            unpack_output_types_dict[section_base_name].append(unpack_output_type)
            unpack_input_sizes_dict[section_base_name] += unpack_input_size

        # Collect the tensor types of the TGI nodes. If the TGI node is the input of a unpack node
        # or a detessellate node, then set the tensor type to be number of bytes per batch.
        # Otherwise, use the actual output data type with c16 aligned shape as the tensor type.
        # NOTE: detessellate_param, output_type and output_shape are left over from the last
        # iteration of the inner loop; they are only read when the section has a single output.
        if len(unpack_output_list) > 1 or not detessellate_param.enable_mla:
            tgi_output_type = TensorType(
                ScalarType.int8, (batch_size, unpack_input_sizes_dict[section_base_name])
            )
        else:
            tgi_output_type = TensorType(output_type.scalar, output_shape)
        tgi_output_types.append(tgi_output_type)

    # Insert TupleGetItem nodes if multiple unpack nodes need to be created.
    if len(tgi_output_types) > 1:
        # One new TGI node per section, reading the section's buffer from the MLA node.
        input_quant = TupleValue(
            [
                TensorValue(
                    QuantResultTensorType(type=tensor_type, quant=None, requant_method=None)
                ) for tensor_type in tgi_output_types
            ]
        )
        unpack_input_nodes = create_tuple_get_item_nodes(
            mla_node.name, tgi_output_types, f"{mla_node.name}_unpack", input_quant
        )
        node_type_output = TupleValue(
            [TensorValue(tgi_output_type) for tgi_output_type in tgi_output_types]
        )
        net.nodes.update({node.name: node for node in unpack_input_nodes})
        sima_log_dbg(f"Inserted TGI nodes: {[node.name for node in unpack_input_nodes]}")
    else:
        # A single section is read straight from the MLA node.
        unpack_input_nodes = [mla_node]
        node_type_output = TensorValue(tgi_output_types[0])

    # Generate unpack nodes, reconnect tgi nodes to the new unpack nodes and generate output nodes.
    # output_input_nodes[i] is the node that feeds the output node of MLA output i.
    output_input_nodes = mla_output_nodes.copy()
    output_nodes = [None] * len(mla_node.ir.graph.outputs)
    for (section_base_name, unpack_output_list), unpack_input_node, unpack_output_types in zip(
        unpack_params.items(), unpack_input_nodes, unpack_output_types_dict.values()
    ):
        # Generate unpack node if there are more than one output in this group.
        if len(unpack_output_list) > 1:
            input_size = unpack_input_sizes_dict[section_base_name]
            unpack_node_name = NodeName(f"{mla_node.name}_{section_base_name.replace('.', '_')}")
            unpack_node = _generate_unpack_node(
                batch_size, input_size, unpack_output_types, unpack_input_node.name,
                unpack_node_name
            )
            net.nodes.update({unpack_node.name: unpack_node})

        if len(mla_output_nodes) > 1:
            # Update TGI nodes. If no unpack node is added, then remove the corresponding tgi nodes.
            if len(unpack_output_list) > 1:
                # The original TGI nodes now index into the unpack node's outputs.
                _update_post_mla_segment_tuple_get_item_nodes(
                    unpack_node, mla_output_nodes, unpack_output_list
                )
            else:
                # Single-output section: the output is read from this section's input node
                # instead, so the original TGI node for it is dropped.
                output_id = unpack_output_list[0][0]
                del net.nodes[mla_output_nodes[output_id].name]
                output_input_nodes[output_id] = unpack_input_node

        # Generate output nodes.
        for output_id, detessellate_param in unpack_output_list:
            output_nodes[output_id] = _generate_output_node(
                output_input_nodes[output_id], detessellate_param, output_types[output_id]
            )

    # Redirect references to the old MLA outputs to the new output nodes. An output node may be
    # the unchanged input node (when no slice/detessellate is needed), giving an identity entry.
    # The output nodes are added to net.nodes only after the renaming, presumably so their own
    # input references are not rewritten - TODO confirm against rename_mut_awesomenet.
    rename_mut_awesomenet(
        Renaming(
            {
                mla_output_node.name: output_node.name
                for mla_output_node, output_node in zip(mla_output_nodes, output_nodes)
            }
        ),
        net
    )
    net.nodes.update({node.name: node for node in output_nodes})
    mla_node.ir.unpack_parameters = unpack_params
    # The MLA node now produces one buffer per section (tuple if more than one).
    mla_node.ir.type = NodeType(
        mla_node.ir.get_type().inputs,
        node_type_output
    )