#########################################################
# Copyright (C) 2020 SiMa Technologies, Inc.
#
# This material is SiMa proprietary and confidential.
#
# This material may not be copied or distributed without
# the express prior written permission of SiMa.
#
# All rights reserved.
#########################################################
# Code owner: Joey Chou
#########################################################
from abc import abstractmethod
from collections import OrderedDict
import copy
from dataclasses import dataclass
from math import prod
from pathlib import Path
import numpy as np
from typing import Dict, Any, Optional, Type, List, Union, Set
from typing_extensions import assert_never
import afe.ir.quantization_utils
from afe._tvm._utils import convert_axis_to_non_negative
from afe.backends.mla.afe_to_n2a_compiler.n2a_backend_runner import N2ABackendSimulator, N2ACompiledBackendRunner, EvaluateTaskType
from afe.backends.mla.afe_to_n2a_compiler.n2a_compiler_utils import make_mlc_file_name, gen_compiler_rms_norm_vertex
from afe.backends import Backend, BackendIR
from afe.backends.mla.afe_to_n2a_compiler import n2a_compiler_utils
from afe.backends.mla.afe_to_n2a_compiler.defines import (
Activation, ArithFoldedRequantization, CompilerVertex, QuantizedWeightDtypes,
FractionalZeroRequantization, L2CachingMode, ModelGraph, RoundType, TFLiteRequantization,
TessellateParameters, TuplePlaceholderOperator, bfloat16, compute_bf16_exp_lut,
compute_bf16_reciprocal_lut, compute_bf16_rsqrt_lut, create_random_tensor, export_model_to_json,
generate_l1_based_model, get_id_requantization, get_max_batch_size_and_quadrants_used,
import_model_from_json, set_model_compilation_properties
)
from afe.backends.mla.afe_to_n2a_compiler.insert_nodes import (
insert_pre_mla_segment_nodes, insert_post_mla_segment_nodes
)
import afe.ir.attributes as attributes
from afe.ir.defines import (
InputName, AwesomeDataLayout, AwesomeDataLayout5D, NoneType, NodeName, Status, DataValue,
get_expected_tensor_value
)
from afe.ir.net import AwesomeNet, is_one_mla_segment_net, update_awesomenet_status
from afe.ir.node import AwesomeNode, node_is_awesomenet
import afe.ir.operations as operations
import afe.ir.operation_functions as op_fn
from afe.ir.quantization_utils import quantization_data_value_to_output_list
from afe.ir.tensor_type import TensorType, ScalarType
import afe.ir.utils as utils
from mlc.compiler.model_graph import PlaceholderName, PlaceholderValues
from mlc.test_util.test_context import CompilerConfig
from sima_utils.common import Platform
from sima_utils.logging.sima_logger import sima_log_dbg, sima_log_error, sima_log_info, UserFacingException
from ml_kernels.requantization import FloatRequantization, Renormalization
@dataclass
class MLACompilerConfig:
"""
Parameters that control how to run the Production Compiler for a model.
Parameters for a specific subgraph or MLC file do not belong here. They should be passed
when calling the production compiler, instead.
:param tessellate_parameters: Dictionary containing information on tessellation parameters
for inputs and outputs of the MLA subgraph. For more information on how parameters are interpreted,
please see insert_nodes.py.
:param verify_mlc_files: Whether to verify correctness of the generated mlc files. Default is True.
It should be disabled only in test cases which also execute the MLA nodes using N2ABackendSimulator.
:param enable_large_tensors: If true, the MLA will handle large tensors, otherwise large tensors
will raise an exception.
:param l2_caching_mode: Parameters controlling L2 caching in compiler.
:param platform_type: Target MLA architecture type.
:param use_power_limits: If true, the compiler will schedule instructions to conform to power limits.
:param max_power: Set to a positive float value to override default max power when power limits are used.
:param compress: If true, the compiler will compress the data.
:param layer_norm_use_fp32_intermediates: Use FP32 intermediate tensors in BF16 LayerNorm kernel.
:param rms_norm_use_fp32_intermediates: Use FP32 intermediate tensors in BF16 RMSNorm kernel.
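
Example (an illustrative sketch only; the values shown are arbitrary and the
target platform must match the device being compiled for):

    mla_config = MLACompilerConfig(
        verify_mlc_files=False,
        l2_caching_mode=L2CachingMode.NONE,
        use_power_limits=True,
        max_power=10.0,
    )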
"""
tessellate_parameters: Optional[TessellateParameters] = None
verify_mlc_files: bool = True
enable_large_tensors: bool = True
l2_caching_mode: L2CachingMode = L2CachingMode.NONE
use_power_limits: bool = False
max_power: float | None = None
# The fields below are documented above and read by compile_mla_code; their defaults are assumed.
platform_type: Platform = Platform.GEN1
compress: bool = True
layer_norm_use_fp32_intermediates: bool = True
rms_norm_use_fp32_intermediates: bool = True
def _gen_modelgraph_vertex(node: AwesomeNode, vertices: Dict[str, CompilerVertex], data_layout: str,
rounding_type: str = RoundType.UPWARD) -> CompilerVertex:
"""
Generate CompilerVertex that is associated with specified AwesomeNode.
:param node: AwesomeNode that will be converted.
:param vertices: Mapping from node name to the CompilerVertex already generated for that node.
:param data_layout: Data layout of the input.
:param rounding_type: Method of rounding.
:return: CompilerVertex
"""
# PlaceholderOp does not have inputs.
if isinstance(node.ir.operation, operations.PlaceholderOp):
input_vertices = {}
else:
input_vertices = OrderedDict({input_name: vertices[input_node_name] for input_name, input_node_name
in zip(node.input_names, node.input_node_names)})
afe_to_n2a_converter = get_afe_to_n2a_converter(node)
return afe_to_n2a_converter.gen_vertex(node.get_type().output, node.ir.attrs, node.ir.quant_attrs,
input_vertices, data_layout, node.name, rounding_type)
def create_modelgraph(
net: AwesomeNet, rounding_type: RoundType = RoundType.UPWARD) -> ModelGraph:
"""
Convert a quantized AwesomeNet to a ModelGraph.
:param net: A quantized AwesomeNet.
:param rounding_type: Rounding method used in the model graph.
:return: A ModelGraph.
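
Example (sketch; assumes `quantized_net` is an AwesomeNet that has already been
quantized and whose inputs all have the same 4D or 5D rank):

    model_graph = create_modelgraph(quantized_net, rounding_type=RoundType.UPWARD)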
"""
# Determine the model layout based on the input rank.
input_node_names = list(net.input_node_names)
ndim = len(net.nodes[input_node_names[0]].ir.get_attrs().type.shape)
assert all(len(net.nodes[name].ir.get_attrs().type.shape) == ndim for name in input_node_names), \
"Expected all inputs to have the same rank (4D or 5D) for ModelGraph creation"
data_layout = AwesomeDataLayout if ndim == 4 else AwesomeDataLayout5D
vertices = OrderedDict()
for node_name in net._execution_order:
node = net.nodes[node_name]
# TODO: Remove data_layout from compiler interface once the ConvertLayout is fully implemented
vertices[node_name] = _gen_modelgraph_vertex(node, vertices, data_layout=data_layout,
rounding_type=rounding_type)
# Collect the input vertices for the ModelGraph.
input_vertices = [vertices[node_name] for node_name in net.input_node_names]
# Create output vertices for the ModelGraph. If the last vertex of the net is a Tuple vertex,
# use its inputs to create the output vertices; otherwise, use the last vertex to create the
# output vertex.
output_vertex = list(vertices.values())[-1]
if isinstance(output_vertex._operator, TuplePlaceholderOperator):
output_vertices = [
n2a_compiler_utils.gen_compiler_output_vertex(v) for v in output_vertex._inputs
]
else:
output_vertices = [n2a_compiler_utils.gen_compiler_output_vertex(output_vertex)]
return ModelGraph(input_vertices, output_vertices)
def translate_sub_awesome_net_to_modelgraph(net: AwesomeNet, force_update_status: bool = False) -> Union[AwesomeNet, ModelGraph]:
"""
Take in an AwesomeNet and translate the MLA-supported sub-graphs to ModelGraphs.
:param net: AwesomeNet.
:param force_update_status: If True, force the status update of the AwesomeNet.
:return: An AwesomeNet whose MLA sub-graphs have been replaced by ModelGraph(s), or a single
ModelGraph if the entire AwesomeNet is one MLA segment.
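
Example (sketch; assumes `net` is a quantized AwesomeNet that has already been
partitioned into backend sub-graphs):

    result = translate_sub_awesome_net_to_modelgraph(net)
    if isinstance(result, ModelGraph):
        ...  # The whole network is a single MLA segment.
    else:
        ...  # AwesomeNet whose MLA sub-graphs now hold BackendIR ModelGraphs.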
"""
# Try to convert the entire AwesomeNet to ModelGraph if the given AwesomeNet
# does not contain sub-graphs that can be translated to MLA backend IR.
if is_one_mla_segment_net(net):
return create_modelgraph(net)
for node_name in net.execution_order:
node = net.nodes[node_name]
if node_is_awesomenet(node):
# This function only supports MLA backend
assert node.ir.backend == Backend.MLA
# Use sub-graph's placeholders' input quantization parameters as
# the input quantization parameters of the backend graph
input_layer_bits = []
input_zps = []
input_scales = []
for name in node.ir.input_node_names:
input_quantization = \
attributes.get_data_value_quant_result_scale_with_dummy(node.ir.nodes[name].ir.calib_attrs.quant)
scale, zero_point, layer_bits, _, _ = quantization_data_value_to_output_list(input_quantization)
input_layer_bits.append(layer_bits)
input_zps.append(zero_point)
input_scales.append(scale)
net_type = node.get_type()
node.ir = BackendIR(create_modelgraph(node.ir), net_type, Backend.MLA)
update_awesomenet_status(net, Status.BACKEND_IR_LOWERED, force_update_status)
return net
def create_backend_runner(net: AwesomeNet, mla_config: MLACompilerConfig, output_path: str,
batch_size: int, stage: int) -> N2ABackendSimulator:
"""
Create an N2ABackendSimulator that is used to verify the generated .mlc files.
:param net: AwesomeNet.
:param mla_config: Parameters controlling how to invoke the Production Compiler.
:param output_path: Output path for .mlc files.
:param batch_size: Batch size for which the model has been compiled.
:param stage: Stage number of the graph.
:return: The created N2ABackendSimulator.
"""
def _report_sim_failure(info: str) -> None:
raise RuntimeError(f"Simulation has detected a failure in {info}")
return N2ABackendSimulator(out_dir=output_path,
layout='NHWC',
model_name=net.name,
batch_size=batch_size,
platform_type=mla_config.platform_type,
report_sim_failure=_report_sim_failure,
evaluate_task_type=EvaluateTaskType.EXT_CMD)
def _create_ofm_chk_mlc(
backend_runner: N2ABackendSimulator,
node: AwesomeNode,
placeholder_values: PlaceholderValues,
tessellated_placeholder_values: dict[InputName, np.ndarray],
stage: int
):
backend_runner.evaluate_model_graph(
node.ir.graph, tessellated_placeholder_values, untessellated_inputs=placeholder_values,
stage=stage
)
def get_awesomenet_max_batch_size(net: AwesomeNet,
desired_batch_size: int) -> Optional[int]:
"""
Traverse the nodes of an AwesomeNet and find the maximum batch size supported
by the compiler for all MLA subnets, taken as the minimum of the maximum batch
sizes supported by the compiler for each individual MLA subnet. If there are
no MLA subnets in the AwesomeNet, None is returned.
:param net: The AwesomeNet whose maximum supported batch size is to be determined.
:param desired_batch_size: The AwesomeNet inputs' desired batch size to be used in compilation.
Compilation will choose the largest size that is supported for the entire AwesomeNet and that
is no larger than this.
:return: An integer representing the maximum batch size supported by the compiler for the
AwesomeNet, not larger than the desired batch size, or None if no MLA segments are
present in the AwesomeNet.
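
Example (sketch; assumes `net` already contains MLA BackendIR segments):

    batch_size = get_awesomenet_max_batch_size(net, desired_batch_size=8)
    if batch_size is None:
        batch_size = 8  # No MLA segments; the desired batch size can be used as-is.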
"""
nodes_max_batch_size: Set[int] = set()
for node_name in net.execution_order:
node = net.nodes[node_name]
if isinstance(node.ir, BackendIR):
max_batch_size, _ = get_max_batch_size_and_quadrants_used(node.ir.graph,
desired_batch_size=desired_batch_size)
nodes_max_batch_size.add(max_batch_size)
return min(nodes_max_batch_size) if nodes_max_batch_size else None
def _get_placeholder_values(
net: AwesomeNet, node: AwesomeNode, batch_size: int
) -> PlaceholderValues:
"""
Generate random placeholder values that are used when verifying the generated mlc files.
:param net: AwesomeNet.
:param node: AwesomeNode holding the MLA BackendIR segment.
:param batch_size: Batch size used to create the random input tensors.
:return: PlaceholderValues mapping each input placeholder name to a random tensor.
"""
input_types = [
net.nodes[input_node_name].get_type().output for input_node_name in node.input_node_names
]
input_data = list()
rng = np.random.default_rng(2)
for input_type in input_types:
in_type = get_expected_tensor_value(input_type)
shape = in_type.shape
# FIXME: Force to use bfloat16 if float32 is provided.
dtype = in_type.scalar.numpy_type()
if dtype == np.float32:
dtype = bfloat16
data = create_random_tensor((batch_size, *shape[1:]), dtype, rng=rng)
input_data.append(data)
placeholder_values = {
PlaceholderName(input_name): in_data
for input_name, in_data in zip(node.input_names, input_data)
}
return placeholder_values
def _tessellate_placeholder_values(
net: AwesomeNet, node: AwesomeNode, placeholder_values: PlaceholderValues
) -> dict[InputName, np.ndarray]:
pl_values = list(placeholder_values.values())
tessellated_placeholder_values = dict()
for input_node_name, pack_input_list in zip(
node.input_node_names, node.ir.graph.compile_properties.pack_parameters.values()
):
_node = net.nodes[input_node_name]
if isinstance(_node.ir.attrs, attributes.PackTransformAttrs):
pack_inputs = list()
for node_name, (input_id, tessellate_param) in (
zip(_node.input_node_names, pack_input_list)
):
pl_value = pl_values[input_id]
if tessellate_param.enable_mla:
pack_input = pl_value
else:
tessellate_node = net.nodes[node_name]
attrs = tessellate_node.ir.attrs
assert isinstance(attrs, attributes.TessellationTransformAttrs), (
f"Expected TessellationTransformAttrs, got {attrs}"
)
pack_input = op_fn.tessellation(attrs, pl_value)
pack_inputs.append(pack_input)
input_data = op_fn.pack(pack_inputs)
else:
assert len(pack_input_list) == 1
input_id = pack_input_list[0][0]
if isinstance(_node.ir.attrs, attributes.TessellationTransformAttrs):
input_data = op_fn.tessellation(_node.ir.attrs, pl_values[input_id])
else:
input_data = op_fn.reshape_to_mla_padded_2d_shape(pl_values[input_id])
tessellated_placeholder_values[InputName(input_node_name)] = input_data
return tessellated_placeholder_values
def compile_mla_code(net: AwesomeNet, output_path: str, mla_config: MLACompilerConfig, *,
desired_batch_size: int = 1) -> int:
"""
Take in an AwesomeNet and compile the MLA supported sub-graphs. Add nodes that are
needed in order to execute the model on MLA.
- Add Tessellate node for each input node to the MLA sub-graph.
- Add Pack nodes if there are multiple inputs to the MLA sub-graph.
- Add Unpack nodes if there are multiple outputs from the MLA sub-graph.
- Add Detessellate node for each output node from the MLA sub-graph.
:param net: AwesomeNet to be compiled.
:param output_path: str. The path to where the generated .mlc files are written.
:param mla_config: Parameters controlling how to invoke the Production Compiler for all MLA model graphs.
:param desired_batch_size: The AwesomeNet inputs' desired batch size to be used in compilation.
Compilation will choose the largest size that is supported for the entire AwesomeNet and that
is no larger than this. The chosen batch size will be returned.
:return: An integer value representing batch size used by production compiler. Also, writes the generated
.mlc files to the output_path.
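
Example (sketch; assumes `net` has already been lowered to BackendIR, e.g. by
translate_sub_awesome_net_to_modelgraph, and that the output directory exists):

    mla_config = MLACompilerConfig(l2_caching_mode=L2CachingMode.NONE)
    used_batch_size = compile_mla_code(net, "/tmp/mlc_out", mla_config,
                                       desired_batch_size=4)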
"""
uncompiled_nodes: List[NodeName] = copy.deepcopy(net.execution_order)
# Determine the maximal supported batch size for all MLA subnets in AwesomeNet.
max_supported_batch_size = get_awesomenet_max_batch_size(net, desired_batch_size)
compiler_batch_size = desired_batch_size if max_supported_batch_size is None else max_supported_batch_size
stage = 1
while uncompiled_nodes:
node_name = uncompiled_nodes.pop(0)
node = net.nodes[node_name]
if node.ir.backend == Backend.APU:
stage += 1
elif isinstance(node.ir, BackendIR) and node.ir.backend == Backend.MLA:
compiler_config = CompilerConfig(
make_mlc_file_name(output_path, net.name, stage),
mla_config.platform_type,
False,
mla_config.use_power_limits,
mla_config.max_power,
mla_config.compress,
mla_config.layer_norm_use_fp32_intermediates,
mla_config.rms_norm_use_fp32_intermediates)
enable_l1_check = False
save_model_graph = False
use_saved_model_graph = False
json_dir = Path(output_path) / "model_graph_json"
json_name = json_dir / f"{net.name}_stage{stage}.json"
if use_saved_model_graph:
model_graph, _ = import_model_from_json(json_name)
node.ir.graph = model_graph
else:
model_graph = node.ir.graph
set_model_compilation_properties(
compiler_config,
model_graph,
batch_size=compiler_batch_size,
tessellate_parameters=mla_config.tessellate_parameters,
enable_large_tensors=mla_config.enable_large_tensors,
validate=False
)
if save_model_graph:
json_dir.mkdir(parents=True, exist_ok=True)
export_model_to_json(model_graph, json_dir, json_name, dict())
# TODO: Add checks if the tessellation parameters are valid, due to EV plugin limitations,
# number of inputs and input shapes.
# Get placeholder values before adding pre-mla nodes.
if mla_config.verify_mlc_files:
placeholder_values = _get_placeholder_values(net, node, compiler_batch_size)
insert_pre_mla_segment_nodes(net, node, model_graph.compile_properties.pack_parameters)
insert_post_mla_segment_nodes(
net, node, model_graph.compile_properties.unpack_parameters, uncompiled_nodes
)
if mla_config.verify_mlc_files:
backend_runner = create_backend_runner(
net, mla_config, output_path, compiler_batch_size, stage
)
tessellated_placeholder_values = _tessellate_placeholder_values(
net, node, placeholder_values
)
if enable_l1_check:
from mlc.compiler.model_graph.evaluate import evaluate_model
from mlc.compiler.model_graph.l1_based import set_ofm_refs
enable_layer_ofm_chk_dict = {
".*": True,
}
sima_log_dbg("Evaluating model graph to insert l1 checks")
ofm_ref = set_ofm_refs(
model_graph, placeholder_values, False,
enable_layer_ofm_chk_dict=enable_layer_ofm_chk_dict
)
N2ACompiledBackendRunner.write_inputs_to_ifm_file(
tessellated_placeholder_values,
model_graph.compile_properties.pack_parameters,
f"{output_path}/{net.name}_stage{stage}_mla.ifm.mlc"
)
N2ACompiledBackendRunner.pack_and_write_outputs_to_ofm_chk_file(
ofm_ref, model_graph.compile_properties.unpack_parameters,
f"{output_path}/{net.name}_stage{stage}_mla.ofm_chk.mlc"
)
else:
_create_ofm_chk_mlc(
backend_runner, node, placeholder_values, tessellated_placeholder_values,
stage
)
try:
generate_l1_based_model(
model_graph.compile_properties, None, use_dram=True, num_dma_controllers=4,
l2_caching_mode=mla_config.l2_caching_mode
)
except RuntimeError as e:
# If compilation fails, stop the evaluate process (if one was started) before exiting.
if mla_config.verify_mlc_files and backend_runner.evaluate_task is not None:
backend_runner.evaluate_task.terminate()
raise e
if mla_config.verify_mlc_files:
if enable_l1_check:
pass
elif backend_runner.evaluate_task_type == EvaluateTaskType.EXT_CMD:
if backend_runner.evaluate_task.poll() is None:
sima_log_dbg('Waiting for evaluate process to finish.')
stdout, stderr = backend_runner.evaluate_task.communicate()
for l in (stdout.decode()).splitlines():
sima_log_info(f" {l}")
for l in (stderr.decode()).splitlines():
sima_log_error(f" {l}")
if backend_runner.evaluate_task.returncode != 0:
raise RuntimeError("Error occurred when creating ofm_chk.mlc file")
elif backend_runner.evaluate_task_type == EvaluateTaskType.FORK_PROCESS:
backend_runner.evaluate_task.join()
backend_runner.execute_model_graph(stage)
# Update stage number in BackendIR
node.ir.stage = stage
stage += 1
net.topological_sort()
update_awesomenet_status(net, Status.BACKEND_IR_COMPILED)
return compiler_batch_size
#############################################
# n2a_compiler interfaces
#############################################
class AfeToN2AOperationConverter:
"""
Abstract base class providing an interface for generation
of CompilerVertex from an AwesomeOperation.
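
Typical use, mirroring `_gen_modelgraph_vertex` above (sketch):

    converter = get_afe_to_n2a_converter(node)
    vertex = converter.gen_vertex(node.get_type().output, node.ir.attrs,
                                  node.ir.quant_attrs, input_vertices,
                                  data_layout, node.name, rounding_type)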
"""
@staticmethod
@abstractmethod
def gen_vertex(output_type: DataValue[TensorType],
attrs: Optional[attributes.AwesomeAttributes],
quant_attrs: Optional[attributes.AwesomeQuantAttrBase],
input_dict: Dict[InputName, Any],
data_layout: str,
name: str,
rounding_type: RoundType = RoundType.UPWARD) -> CompilerVertex:
raise NotImplementedError("Abstract base class, use appropriate child class to generate CompilerVertex")
class PlaceholderN2ACompiler(AfeToN2AOperationConverter):
@staticmethod
def gen_vertex(output_type: DataValue[TensorType],
attrs: Optional[attributes.AwesomeAttributes],
quant_attrs: Optional[attributes.AwesomeQuantAttrBase],
input_dict: Dict[InputName, Any],
data_layout: str,
name: str,
rounding_type: RoundType = RoundType.UPWARD) -> CompilerVertex:
# Get output shape
output_type = get_expected_tensor_value(output_type)
shape = output_type.shape
# Convert shape to compiler supported shape
shape = n2a_compiler_utils.shape_to_compiler_data_layout(shape, current_layout=data_layout)
# Create fake data and convert shape to HWC
dtype = n2a_compiler_utils.fix_dtype(output_type.scalar.numpy_type())
fake_data = n2a_compiler_utils.get_fake_data(shape, dtype)
# Generate vertex and return it
return n2a_compiler_utils.gen_compiler_placeholder_vertex(name, fake_data)
class ConstantN2ACompiler(AfeToN2AOperationConverter):
@staticmethod
def gen_vertex(output_type: DataValue[TensorType],
attrs: Optional[attributes.AwesomeAttributes],
quant_attrs: Optional[attributes.AwesomeQuantAttrBase],
input_dict: Dict[InputName, Any],
data_layout: str,
name: str,
rounding_type: RoundType = RoundType.UPWARD) -> CompilerVertex:
if quant_attrs is None:
assert isinstance(attrs, attributes.ConstantAttrs)
data = attrs.data
else:
assert isinstance(quant_attrs, attributes.ConstantQuantAttrs)
data = quant_attrs.quant_data
# Leave out the batch dimension ('NHWC' layout is assumed), if constant is not 1D vector.
if data.ndim == 5:
data = data[0]
elif data.ndim > 1:
assert data.ndim == 4 and data.shape[0] == 1
data = data[0]
# Generate vertex and return it
return n2a_compiler_utils.gen_compiler_constant_vertex(name, data)
class TupleN2A(AfeToN2AOperationConverter):
@staticmethod
def gen_vertex(output_type: DataValue[TensorType],
attrs: Optional[attributes.AwesomeAttributes],
quant_attrs: Optional[attributes.AwesomeQuantAttrBase],
input_dict: Dict[InputName, Any],
data_layout: str,
name: str,
rounding_type: RoundType = RoundType.UPWARD) -> CompilerVertex:
vertices = [v for v in input_dict.values()]
return n2a_compiler_utils.gen_compiler_tuple_vertex(name, vertices)
class ConcatenateN2A(AfeToN2AOperationConverter):
@staticmethod
def gen_vertex(output_type: DataValue[TensorType],
attrs: Optional[attributes.AwesomeAttributes],
quant_attrs: Optional[attributes.AwesomeQuantAttrBase],
input_dict: Dict[InputName, Any],
data_layout: str,
name: str,
rounding_type: RoundType = RoundType.UPWARD) -> CompilerVertex:
# Get input vertices
if len(input_dict) == 1:
input_vertices = input_dict['data'].inputs
else:
input_vertices = list(input_dict.values())
# Get Concat attributes
is_quantized = quant_attrs is not None
if is_quantized:
assert isinstance(quant_attrs, attributes.ConcatQuantAttrs)
axis = quant_attrs.attrs.axis
requants = quant_attrs.requants
# Verify that the scale correction is representable in the compiler.
# It is limited to int8 when the input type is int8.
assert all(t.scalar == quant_attrs.attrs.input_types[0].scalar for t in quant_attrs.attrs.input_types)
if quant_attrs.attrs.input_types[0].scalar == ScalarType.int8:
sc_correction_limit = np.iinfo(np.int8).max
else:
sc_correction_limit = np.iinfo(np.int32).max
for requant in requants:
assert isinstance(requant, FractionalZeroRequantization)
assert 0 <= requant.sc_correction <= sc_correction_limit, "Scale correction is not representable"
else:
assert isinstance(attrs, attributes.ConcatenateAttrs)
axis = attrs.axis
output_dtype = input_vertices[0].operator.shape.dtype.np_dtype
requants = [get_id_requantization(output_dtype)] * len(input_vertices)
axis = n2a_compiler_utils.axis_to_compiler_data_layout(axis, data_layout)
return n2a_compiler_utils.gen_compiler_concat_vertex(
name, input_vertices, requants, axis)
class PreluN2A(AfeToN2AOperationConverter):
@staticmethod
def gen_vertex(output_type: DataValue[TensorType],
attrs: Optional[attributes.AwesomeAttributes],
quant_attrs: Optional[attributes.AwesomeQuantAttrBase],
input_dict: Dict[InputName, Any],
data_layout: str,
name: str,
rounding_type: RoundType = RoundType.UPWARD) -> CompilerVertex:
datav, = input_dict.values()
is_quantized = quant_attrs is not None
if is_quantized:
assert isinstance(quant_attrs, attributes.PReluQuantAttrs)
zp = quant_attrs.data_zero_point
shift = quant_attrs.alpha_shift
alpha = quant_attrs.quant_alpha
override_rounding_type = rounding_type
else:
assert isinstance(attrs, attributes.PReluAttrs)
zp = shift = 0
alpha = attrs.alpha
override_rounding_type = RoundType.TOEVEN
alphav = n2a_compiler_utils.gen_compiler_prelu_alpha_vertex(f"{name}_alpha", alpha)
return n2a_compiler_utils.gen_compiler_prelu_vertex(
name, datav, alphav, zp, shift, override_rounding_type
)
class TupleConcatenateN2A(AfeToN2AOperationConverter):
@staticmethod
def gen_vertex(output_type: DataValue[TensorType],
attrs: Optional[attributes.AwesomeAttributes],
quant_attrs: Optional[attributes.AwesomeQuantAttrBase],
input_dict: Dict[InputName, Any],
data_layout: str,
name: str,
rounding_type: RoundType = RoundType.UPWARD) -> CompilerVertex:
# Modifying input dictionary so that input_dict['data'].inputs is a list
# of input vertices for ConcatenateN2A
class _DummyVertex:
inputs = [v for v in input_dict.values()]
if isinstance(quant_attrs, attributes.ConcatQuantAttrs):
return ConcatenateN2A.gen_vertex(
output_type, quant_attrs.attrs, quant_attrs,
{"data": _DummyVertex()}, data_layout, name,
rounding_type=rounding_type)
elif isinstance(attrs, attributes.ConcatenateAttrs):
return ConcatenateN2A.gen_vertex(
output_type, attrs, None,
{"data": _DummyVertex()}, data_layout, name,
rounding_type=rounding_type)
else:
assert isinstance(attrs, attributes.TupleConcatenateAttrs)
return ConcatenateN2A.gen_vertex(
output_type, attrs.concat_attrs, None,
{"data": _DummyVertex()}, data_layout, name,
rounding_type=rounding_type)
class MultiplyN2ACompiler(AfeToN2AOperationConverter):
@staticmethod
def gen_vertex(output_type: DataValue[TensorType],
attrs: Optional[attributes.AwesomeAttributes],
quant_attrs: Optional[attributes.AwesomeQuantAttrBase],
input_dict: Dict[InputName, Any],
data_layout: str,
name: str,
rounding_type: RoundType = RoundType.UPWARD) -> CompilerVertex:
# Get input vertices
lhs_vertex = input_dict["lhs"]
rhs_vertex = input_dict["rhs"]
is_quantized = quant_attrs is not None
if is_quantized:
assert isinstance(quant_attrs, attributes.MultiplyQuantAttrs)
zp_lhs = quant_attrs.lhs_zero_point
zp_rhs = quant_attrs.rhs_zero_point
requant = quant_attrs.requant
intrinsic_shift = quant_attrs.intrinsic_shift
else:
zp_lhs = zp_rhs = 0
requant = get_id_requantization(bfloat16)
intrinsic_shift = 0
return n2a_compiler_utils.gen_compiler_mul_vertex(
name, lhs_vertex, rhs_vertex, zp_lhs, zp_rhs, requant, intrinsic_shift)
class MaxPoolN2ACompiler(AfeToN2AOperationConverter):
@staticmethod
def gen_vertex(output_type: DataValue[TensorType],
attrs: Optional[attributes.AwesomeAttributes],
quant_attrs: Optional[attributes.AwesomeQuantAttrBase],
input_dict: Dict[InputName, Any],
data_layout: str,
name: str,
rounding_type: RoundType = RoundType.UPWARD) -> CompilerVertex:
# Get input vertices
data_vertex = input_dict['data']
# Get MaxPool attributes
is_quantized = quant_attrs is not None
if is_quantized:
assert isinstance(quant_attrs, attributes.PoolQuantAttrs)
pool_attrs = quant_attrs.pool_attrs
requant = quant_attrs.requant
else:
assert isinstance(attrs, attributes.MaxPoolAttrs)
pool_attrs = attrs
requant = None
pool_size = tuple(pool_attrs.pool_size[1:-1])
strides = tuple(pool_attrs.strides[1:-1])
padding = sum(pool_attrs.padding[1:-1], start=tuple())
if len(pool_attrs.pool_size) == 4:
pool_size = (1, *pool_size)
strides = (1, *strides)
padding = (0, 0, *padding)
output_type = get_expected_tensor_value(output_type)
output_dtype = n2a_compiler_utils.fix_dtype(output_type.scalar.numpy_type())
output_shape = n2a_compiler_utils.shape_to_compiler_data_layout(
output_type.shape, current_layout=data_layout
)
# Generate vertex and return it
return n2a_compiler_utils.gen_compiler_maxpool_vertex(
name, output_shape, data_vertex, pool_size, strides, padding, output_dtype, requant=requant)
class AvgPoolN2ACompiler(AfeToN2AOperationConverter):
@staticmethod
def gen_vertex(output_type: DataValue[TensorType],
attrs: Optional[attributes.AwesomeAttributes],
quant_attrs: Optional[attributes.AwesomeQuantAttrBase],
input_dict: Dict[InputName, Any],
data_layout: str,
name: str,
rounding_type: RoundType = RoundType.UPWARD) -> CompilerVertex:
# Get input vertices
data_vertex = input_dict['data']
# Get AvgPool attributes
is_quantized = quant_attrs is not None
if is_quantized:
assert isinstance(quant_attrs, attributes.PoolQuantAttrs)
pool_attrs = quant_attrs.pool_attrs
override_rounding_type = quant_attrs.rounding_type
requant = quant_attrs.requant
pad_value = quant_attrs.pad_value
else:
assert isinstance(attrs, attributes.AvgPoolAttrs)
pool_attrs = attrs
override_rounding_type = RoundType.TOEVEN
requant = None
pad_value = 0
pool_size = tuple(pool_attrs.pool_size[1:-1])
strides = tuple(pool_attrs.strides[1:-1])
padding = sum(pool_attrs.padding[1:-1], start=tuple())
if len(pool_attrs.pool_size) == 4:
pool_size = (1, *pool_size)
strides = (1, *strides)
padding = (0, 0, *padding)
output_type = get_expected_tensor_value(output_type)
output_dtype = n2a_compiler_utils.fix_dtype(output_type.scalar.numpy_type())
output_shape = n2a_compiler_utils.shape_to_compiler_data_layout(
output_type.shape, current_layout=data_layout
)
# Generate vertex and return it
input_shape = data_vertex.operator.shape.shape
op = "global" if input_shape[:-1] == pool_size else "average"
return n2a_compiler_utils.gen_compiler_avgpool_vertex(
name, output_shape, data_vertex, op=op, pool_size=pool_size, strides=strides,
padding=padding, rounding_type=override_rounding_type, output_dtype=output_dtype,
requant=requant, pad_value=pad_value
)
class VarianceN2ACompiler(AfeToN2AOperationConverter):
@staticmethod
def gen_vertex(output_type: DataValue[TensorType],
attrs: Optional[attributes.AwesomeAttributes],
quant_attrs: Optional[attributes.AwesomeQuantAttrBase],
input_dict: Dict[InputName, Any],
data_layout: str,
name: str,
rounding_type: RoundType = RoundType.UPWARD) -> CompilerVertex:
# Get input vertices
data_vertex = input_dict[InputName('data')]
mean_vertex = input_dict[InputName('mean')]
is_quantized = quant_attrs is not None
if is_quantized:
assert isinstance(quant_attrs, attributes.VarianceQuantAttrs)
requant = quant_attrs.requant
requant_var = quant_attrs.requant_var
else:
assert isinstance(attrs, attributes.VarianceAttrs)
divisor = np.float32(1.0 / prod(attrs.input_data_shape[1:-1]))
requant = Renormalization(
divisor, utils.create_and_verify_narrowing(0, RoundType.TOEVEN, bfloat16)
)
requant_var = None
return n2a_compiler_utils.gen_compiler_variance_vertex(
name, data_vertex, mean_vertex, requant, requant_var
)
class MeanN2ACompiler(AfeToN2AOperationConverter):
@staticmethod
def gen_vertex(output_type: DataValue[TensorType],
attrs: Optional[attributes.AwesomeAttributes],
quant_attrs: Optional[attributes.AwesomeQuantAttrBase],
input_dict: Dict[InputName, Any],
data_layout: str,
name: str,
rounding_type: RoundType = RoundType.UPWARD) -> CompilerVertex:
"""
Limitations:
* input shape is in the NDHWC format, where N=1.
* mean is only supported along D, H and/or W axes.
* mean does not reduce tensor dimensions, output is NDHWC.
Implementation of mean relies on avgpool.
"""
# Get input vertices
data_vertex = input_dict['data']
input_shape = data_vertex.operator.output_shape
if quant_attrs is not None:
assert isinstance(quant_attrs, attributes.MeanQuantAttrs)
attrs = quant_attrs.attrs
assert isinstance(attrs, attributes.MeanAttrs)
assert attrs.keepdims, "MLA does not support MeanOp with keepdims=False."
assert not attrs.exclude, "Axes are not simplified by SimplifyAxisExcludeAttr."
assert 0 not in attrs.axis, "MLA does not support MeanOp on batch dimension."
assert len(attrs.shape) - 1 not in attrs.axis, (
"MLA does not support MeanOp on channel dimension."
)
# Convert axes to the pool size.
pool_size = tuple(
x if i in attrs.axis else 1
for i, x in enumerate(attrs.shape[1:-1], start=1)
)
if len(attrs.shape) == 4:
pool_size = (1, *pool_size)
pool_op = "global" if pool_size == input_shape[:-1] else "average"
padding = (0, ) * len(pool_size) * 2
output_shape = n2a_compiler_utils.shape_to_compiler_data_layout(
get_expected_tensor_value(output_type).shape, current_layout=data_layout
)
# Generate vertex and return it
# Hardcode the rounding_type because it is hardcoded in
# operation_functions.py::mean
return n2a_compiler_utils.gen_compiler_avgpool_vertex(
name, output_shape, data_vertex, op=pool_op, pool_size=pool_size, strides=pool_size,
padding=padding, rounding_type=RoundType.TRUNC
)
class ReluN2ACompiler(AfeToN2AOperationConverter):
@staticmethod
def gen_vertex(output_type: DataValue[TensorType],
attrs: Optional[attributes.AwesomeAttributes],
quant_attrs: Optional[attributes.AwesomeQuantAttrBase],
input_dict: Dict[InputName, Any],
data_layout: str,
name: str,
rounding_type: RoundType = RoundType.UPWARD) -> CompilerVertex:
# Get input vertices
data_vertex = input_dict["data"]
# Get zero point
is_quantized = quant_attrs is not None
if is_quantized:
assert isinstance(quant_attrs, attributes.ReluQuantAttrs)
node_zp = quant_attrs.zero_point
else:
node_zp = 0
return n2a_compiler_utils.gen_compiler_relu_vertex(
name, data_vertex, node_zp)
class ClipN2ACompiler(AfeToN2AOperationConverter):
@staticmethod
def gen_vertex(output_type: DataValue[TensorType],
attrs: Optional[attributes.AwesomeAttributes],
quant_attrs: Optional[attributes.AwesomeQuantAttrBase],
input_dict: Dict[InputName, Any],
data_layout: str,
name: str,
rounding_type: RoundType = RoundType.UPWARD) -> CompilerVertex:
# Get input vertices
data_vertex = input_dict["data"]
# Get attributes
_attrs = attrs if attrs is not None else quant_attrs
assert isinstance(_attrs, (attributes.ClipAttrs, attributes.ClipQuantAttrs))
return n2a_compiler_utils.gen_compiler_clip_vertex(name, data_vertex, _attrs.a_min, _attrs.a_max)
class LeakyReluN2ACompiler(AfeToN2AOperationConverter):
@staticmethod
def gen_vertex(output_type: DataValue[TensorType],
attrs: Optional[attributes.AwesomeAttributes],
quant_attrs: Optional[attributes.AwesomeQuantAttrBase],
input_dict: Dict[InputName, Any],
data_layout: str,
name: str,
rounding_type: RoundType = RoundType.UPWARD) -> CompilerVertex:
# Get input vertices
data_vertex = input_dict["data"]
is_quantized = quant_attrs is not None
if is_quantized:
assert isinstance(quant_attrs, attributes.LeakyReluCompositeQuantAttrs)
if quant_attrs.leaky_relu_uses_udf:
# Use UDF version.
assert quant_attrs.udf_quant_attrs is not None
lut_array = quant_attrs.udf_quant_attrs.lookup_table
lut_vertex = n2a_compiler_utils.gen_compiler_constant_vertex(
name + '_lut', lut_array)
return n2a_compiler_utils.gen_compiler_udf_vertex(
name, data_vertex, lut_vertex)
else:
# Use breakdown version.
assert quant_attrs.leaky_relu_quant_attrs is not None
node_zp = quant_attrs.leaky_relu_quant_attrs.zero_point
alpha = quant_attrs.leaky_relu_quant_attrs.alpha
right_shift = quant_attrs.leaky_relu_quant_attrs.right_shift
return n2a_compiler_utils.gen_compiler_leaky_relu_vertex(
name, data_vertex, alpha, node_zp, right_shift,
rounding_type)
else:
assert isinstance(attrs, attributes.LeakyReluAttrs)
alpha = attrs.alpha
node_zp = right_shift = 0
return n2a_compiler_utils.gen_compiler_leaky_relu_vertex(
name, data_vertex, alpha, node_zp, right_shift, RoundType.TOEVEN)
class LRNN2ACompiler(AfeToN2AOperationConverter):
@staticmethod
def gen_vertex(output_type: DataValue[TensorType],
attrs: Optional[attributes.AwesomeAttributes],
quant_attrs: Optional[attributes.AwesomeQuantAttrBase],
input_dict: Dict[InputName, Any],
data_layout: str,
name: str,
rounding_type: RoundType = RoundType.UPWARD) -> CompilerVertex:
assert isinstance(quant_attrs, attributes.LRNQuantAttrs)
assert quant_attrs.axis == data_layout.index("C")
# Get input vertices
data_vertex = input_dict["data"]
lut_vertex = n2a_compiler_utils.gen_compiler_constant_vertex(
name + '_lut', quant_attrs.lookup_table.reshape(16, 16))
# Hardcode the rounding_type because it is hardcoded in
# operation_functions.py::lrn
return n2a_compiler_utils.gen_compiler_lrn_vertex(
name, data_vertex, lut_vertex, quant_attrs.size, quant_attrs.input_zp,
quant_attrs.lut_scale, quant_attrs.lut_zp_corr, quant_attrs.lut_sh,
quant_attrs.output_scale, quant_attrs.output_zp_corr, quant_attrs.output_sh,
RoundType.UPWARD)
class UpsamplingN2ACompiler(AfeToN2AOperationConverter):
@staticmethod
def gen_vertex(output_type: DataValue[TensorType],
attrs: Optional[attributes.AwesomeAttributes],
quant_attrs: Optional[attributes.AwesomeQuantAttrBase],
input_dict: Dict[InputName, Any],
data_layout: str,
name: str,
rounding_type: RoundType = RoundType.UPWARD) -> CompilerVertex:
# Get input vertices
data_vertex = input_dict["data"]
is_quantized = quant_attrs is not None
if is_quantized:
assert isinstance(quant_attrs, attributes.UpsamplingQuantAttrs)
_attrs = quant_attrs.upsampling_attrs
zp = quant_attrs.input_zp
override_rounding_type = quant_attrs.rounding_type
else:
assert isinstance(attrs, attributes.UpsamplingAttrs)
_attrs = attrs
zp = 0
override_rounding_type = RoundType.TOEVEN
input_shape: tuple[int, ...] = data_vertex.operator.shape.shape
target_spatial_shape: tuple[int, ...] = (
input_shape[0],
input_shape[1] * _attrs.scale_h,
input_shape[2] * _attrs.scale_w
)
return n2a_compiler_utils.generate_resize_vertex(
data_vertex,
_attrs.method,
target_spatial_shape,
zp,
override_rounding_type,
mode="align_corners" if _attrs.align_corners else "half_pixel",
name=name,
)
class ImageResize2DN2ACompiler(AfeToN2AOperationConverter):
@staticmethod
def gen_vertex(output_type: DataValue[TensorType],
attrs: Optional[attributes.AwesomeAttributes],
quant_attrs: Optional[attributes.AwesomeQuantAttrBase],
input_dict: Dict[InputName, Any],
data_layout: str,
name: str,
rounding_type: RoundType = RoundType.UPWARD) -> CompilerVertex:
# Get input vertices
data_vertex = input_dict["data"]
is_quantized = quant_attrs is not None
if is_quantized:
assert isinstance(quant_attrs, attributes.ImageResize2DQuantAttrs)
image_resize2d_attrs = quant_attrs.image_resize2d_attrs
zp = quant_attrs.input_zp
override_rounding_type = quant_attrs.rounding_type
else:
image_resize2d_attrs = attrs
zp = 0
override_rounding_type = RoundType.TOEVEN
assert isinstance(image_resize2d_attrs, attributes.ImageResize2DAttrs)
target_spatial_shape: tuple[int, ...] = (
data_vertex.operator.output_shape[0],
image_resize2d_attrs.size[0],
image_resize2d_attrs.size[1]
)
mode: str = image_resize2d_attrs.coordinate_transformation_mode
tf_ver = 2 if mode in ['half_pixel', 'pytorch_half_pixel', 'align_corners', 'asymmetric'] else 1
return n2a_compiler_utils.generate_resize_vertex(
data_vertex, image_resize2d_attrs.method, target_spatial_shape, zp,
override_rounding_type, tf_ver, mode, name
)
################################
# Convolution, Add, Activations
################################
class ConvN2ACompiler(AfeToN2AOperationConverter):
"""
Converter for all supported variants of composite 2D convolution, both regular and transposed.
"""
@staticmethod
def _gen_vertex(
output_type: DataValue[TensorType],
attrs: attributes.ConvQuantAttrs | attributes.ConvAddActivationAttrs,
name: str,
data_vertex: CompilerVertex,
weight_data: np.ndarray,
bias_data: Optional[np.ndarray] = None) -> CompilerVertex:
# Get Convolution attributes
conv_attrs: attributes.ConvAttrs = attrs.conv_attrs
match conv_attrs.num_spatial_dimensions:
case 2:
data_layout = "NHWC"
stride = (1, *conv_attrs.stride)
dilation = (1, *conv_attrs.dilation)
padding = sum(conv_attrs.padding, start=(0, 0))
case 3:
data_layout = "NDHWC"
stride = conv_attrs.stride
dilation = conv_attrs.dilation
padding = sum(conv_attrs.padding, start=tuple())
case _ as unreachable:
assert_never(unreachable)
output_type = get_expected_tensor_value(output_type)
output_dtype = n2a_compiler_utils.fix_dtype(output_type.scalar.numpy_type())
output_shape = n2a_compiler_utils.shape_to_compiler_data_layout(
output_type.shape, current_layout=data_layout
)
# Check if it's a transposed convolution
is_transposed = conv_attrs.is_transposed
# Check if it is a depthwise convolution with channel multiplier == 1
is_depthwise = conv_attrs.is_depthwise_one_channel
if is_depthwise and is_transposed:
# Depthwise transposed convolution will be implemented using regular convolution.
# Flip the weight tensor to make the equivalent weights for regular convolution.
weight_data = np.flip(weight_data, axis=tuple(range(conv_attrs.num_spatial_dimensions)))
weight_data = n2a_compiler_utils.to_conv2d_weights_layout(weight_data)
# Create weights and bias constant vertices
weights_vertex = n2a_compiler_utils.gen_compiler_weight_vertex(name, weight_data)
if bias_data is not None:
bias_data = n2a_compiler_utils.cast_bias(bias_data)
bias_vertex = n2a_compiler_utils.gen_compiler_bias_vertex(name, bias_data)
else:
bias_vertex = None
# Get the input zp and requant param
is_quantized = isinstance(attrs, attributes.ConvQuantAttrs)
# Whether bfloat16 with int weights quant scheme is used
is_bfloat16_with_int_weights = (
weight_data.dtype in QuantizedWeightDtypes
and output_type.scalar == ScalarType.bfloat16
)
if is_quantized and not is_bfloat16_with_int_weights:
assert isinstance(attrs.requant, (ArithFoldedRequantization, TFLiteRequantization))
requant = afe.ir.quantization_utils.fix_requantization(attrs.requant)
input_zp = attrs.input_zp
output_zp = attrs.zero_point
msb_left_shift = attrs.msb_left_shift
else:
if is_bfloat16_with_int_weights:
assert isinstance(attrs, attributes.ConvQuantAttrs)
requant = attrs.requant
else:
requant = get_id_requantization(output_dtype)
input_zp = output_zp = 0
msb_left_shift = False
clip_range = None
match attrs.activ_attrs:
case attributes.ReluAttrs() | attributes.ReluQuantAttrs():
# Relu activation will be fused with convolution.
fused_activation = Activation.RELU
case attributes.ClipAttrs() | attributes.ClipQuantAttrs():
# Clip activation will be fused with convolution.
fused_activation = Activation.CLIP
clip_range = attrs.activ_attrs.a_min, attrs.activ_attrs.a_max
case None:
fused_activation = Activation.NONE
case _:
raise ValueError("Unhandled activation type")
gen_vertex_func: n2a_compiler_utils.ConvVertexProtocol
if is_transposed:
gen_vertex_func = n2a_compiler_utils.gen_compiler_conv2d_transpose_vertex
else:
gen_vertex_func = n2a_compiler_utils.gen_compiler_conv2d_vertex
conv_vertex = gen_vertex_func(
name, output_shape, data_vertex, weights_vertex, bias_vertex, input_zp, output_zp,
stride, padding, dilation, msb_left_shift=msb_left_shift, requant=requant,
activ=fused_activation, is_depthwise=is_depthwise, groups=conv_attrs.groups,
clip_range=clip_range
)
return conv_vertex
@staticmethod
def _check_composite_operator_support(
attrs: attributes.ConvAddActivationAttrs | attributes.ConvQuantAttrs):
"""
Checks if the composite convolution operator is supported by N2A compiler.
Currently, supported composite convolution operators are:
Conv + Add
Conv + Add + Relu
ConvTransposed + Add
ConvTransposed + Add + Relu
:param attrs: Pre-quantization attributes or quantization attributes
for the given operation
:return: None. Raises an exception if the composite operation is not supported
"""
_operator_is_supported = isinstance(attrs.activ_attrs, (attributes.ReluAttrs, attributes.ReluQuantAttrs,
attributes.ClipAttrs, attributes.ClipQuantAttrs, NoneType))
if not _operator_is_supported:
raise NotImplementedError(f"N2A Compiler does not support composite convolution operation with "
f"{type(attrs.activ_attrs)} activation attributes.")
@staticmethod
def gen_vertex(output_type: DataValue[TensorType],
attrs: Optional[attributes.AwesomeAttributes],
quant_attrs: Optional[attributes.AwesomeQuantAttrBase],
input_dict: Dict[InputName, Any],
data_layout: str,
name: str,
rounding_type: RoundType = RoundType.UPWARD) -> CompilerVertex:
# Obtain weight and bias data
is_quantized = quant_attrs is not None
if is_quantized:
assert isinstance(quant_attrs, attributes.ConvQuantAttrs)
weight_data = quant_attrs.weight_quant_data
bias_data = quant_attrs.bias_quant_data
cur_attrs = quant_attrs
else:
assert isinstance(attrs, attributes.ConvAddActivationAttrs)
weight_data = attrs.weights_attrs.data
bias_data = attrs.bias_attrs.data if attrs.bias_attrs else None
cur_attrs = attrs
# Check if composite operator is supported
ConvN2ACompiler._check_composite_operator_support(cur_attrs)
# Get input vertices
data_vertex = input_dict['data']
return ConvN2ACompiler._gen_vertex(
output_type, cur_attrs, name, data_vertex, weight_data, bias_data)
################################
# Add, Activations
################################
class AddActivationN2ACompiler(AfeToN2AOperationConverter):
@staticmethod
def gen_vertex(output_type: DataValue[TensorType],
attrs: Optional[attributes.AwesomeAttributes],
quant_attrs: Optional[attributes.AwesomeQuantAttrBase],
input_dict: Dict[InputName, Any],
data_layout: str,
name: str,
rounding_type: RoundType = RoundType.UPWARD) -> CompilerVertex:
is_quantized = quant_attrs is not None
if is_quantized:
assert isinstance(quant_attrs, attributes.AddQuantAttrs)
activ_attrs = quant_attrs.activ_attrs
lhs_scale = quant_attrs.lhs_scale
rhs_scale = quant_attrs.rhs_scale
requant = quant_attrs.requant
else:
assert isinstance(attrs, attributes.AddActivationAttrs)
activ_attrs = attrs.activ_attrs
lhs_scale = rhs_scale = 1.0
requant = get_id_requantization(bfloat16)
clip_range = None
if isinstance(activ_attrs, attributes.ReluAttrs | attributes.ReluQuantAttrs):
activation = Activation.RELU
elif isinstance(activ_attrs, attributes.ClipAttrs | attributes.ClipQuantAttrs):
activation = Activation.CLIP
clip_range = activ_attrs.a_min, activ_attrs.a_max
else:
activation = Activation.NONE
# Get input vertices
lhs_vertex = input_dict["lhs"]
rhs_vertex = input_dict["rhs"]
add_vertex = n2a_compiler_utils.gen_compiler_add_subtract_vertex(
name, lhs_vertex, rhs_vertex, lhs_scale, rhs_scale, requant,
"add", activ=activation, clip_range=clip_range)
return add_vertex
class ConstantMultiplyAddN2ACompiler(AfeToN2AOperationConverter):
@staticmethod
def gen_vertex(output_type: DataValue[TensorType],
attrs: Optional[attributes.AwesomeAttributes],
quant_attrs: Optional[attributes.AwesomeQuantAttrBase],
input_dict: Dict[InputName, Any],
data_layout: str,
name: str,
rounding_type: RoundType = RoundType.UPWARD) -> CompilerVertex:
if quant_attrs is not None:
return AddActivationN2ACompiler.gen_vertex(
output_type, attrs, quant_attrs, input_dict, data_layout, name, rounding_type
)
# For bfloat 16 decompose constant_multiply_add to multiply and add.
assert isinstance(attrs, attributes.ConstantMultiplyAddAttrs)
# Get input vertices
lhs_vertex = input_dict["lhs"]
rhs_vertex = input_dict["rhs"]
lhs_const_vertex = n2a_compiler_utils.gen_compiler_constant_vertex(
f"{name}/lhs_const",
attrs.in1_const_attrs.data
)
lhs_multiply_attrs = attributes.MultiplyAttrs(
attrs.scalar_type, attrs.lhs_input_shape, attrs.in1_const_attrs.data.shape
)
lhs_multiply_input_dict = {"lhs": lhs_vertex, "rhs": lhs_const_vertex}
lhs_vertex = MultiplyN2ACompiler.gen_vertex(
output_type, lhs_multiply_attrs, None, lhs_multiply_input_dict, data_layout,
f"{name}/multiply_lhs", rounding_type
)
if attrs.in2_const_attrs:
rhs_const_vertex = n2a_compiler_utils.gen_compiler_constant_vertex(
f"{name}/rhs_const",
attrs.in2_const_attrs.data
)
rhs_multiply_attrs = attributes.MultiplyAttrs(
attrs.scalar_type, attrs.lhs_input_shape, attrs.in2_const_attrs.data.shape
)
rhs_multiply_input_dict = {"lhs": rhs_vertex, "rhs": rhs_const_vertex}
rhs_vertex = MultiplyN2ACompiler.gen_vertex(
output_type, rhs_multiply_attrs, None, rhs_multiply_input_dict, data_layout,
f"{name}/multiply_rhs", rounding_type
)
lhs_scale = rhs_scale = 1.0
requant = get_id_requantization(bfloat16)
add_vertex = n2a_compiler_utils.gen_compiler_add_subtract_vertex(
name, lhs_vertex, rhs_vertex, lhs_scale, rhs_scale, requant, "add"
)
return add_vertex
class SubtractN2ACompiler(AfeToN2AOperationConverter):
@staticmethod
def gen_vertex(output_type: DataValue[TensorType],
attrs: Optional[attributes.AwesomeAttributes],
quant_attrs: Optional[attributes.AwesomeQuantAttrBase],
input_dict: Dict[InputName, Any],
data_layout: str,
name: str,
rounding_type: RoundType = RoundType.UPWARD) -> CompilerVertex:
# Get input vertices
lhs_vertex = input_dict["lhs"]
rhs_vertex = input_dict["rhs"]
is_quantized = quant_attrs is not None
if is_quantized:
assert isinstance(quant_attrs, attributes.SubtractQuantAttrs)
# Get Subtract attributes
in1_scale = quant_attrs.lhs_scale
in2_scale = quant_attrs.rhs_scale
requant = quant_attrs.requant
else:
in1_scale = in2_scale = 1
requant = get_id_requantization(bfloat16)
return n2a_compiler_utils.gen_compiler_add_subtract_vertex(
name, lhs_vertex, rhs_vertex, in1_scale, in2_scale, requant,
op='sub')
class ArgMaxN2ACompiler(AfeToN2AOperationConverter):
@staticmethod
def gen_vertex(output_type: DataValue[TensorType],
attrs: Optional[attributes.AwesomeAttributes],
quant_attrs: Optional[attributes.AwesomeQuantAttrBase],
input_dict: Dict[InputName, Any],
data_layout: str,
name: str,
rounding_type: RoundType = RoundType.UPWARD) -> CompilerVertex:
input_vertex = input_dict["data"]
if quant_attrs is None:
select_last_index = attrs.select_last_index
else:
select_last_index = quant_attrs.attrs.select_last_index
return n2a_compiler_utils.gen_compiler_arg_min_max_vertex(
name, input_vertex, True, select_last_index
)
class UdfN2ACompiler(AfeToN2AOperationConverter):
@staticmethod
def _gen_vertex(quant_attrs: attributes.UDFQuantAttrs, input_dict: Dict[InputName, Any], name: str) \
-> CompilerVertex:
data_vertex = input_dict['data']
lut_vertex = n2a_compiler_utils.gen_compiler_constant_vertex(
name + '_lut', quant_attrs.lookup_table)
return n2a_compiler_utils.gen_compiler_udf_vertex(name, data_vertex, lut_vertex)
@staticmethod
def gen_vertex(output_type: DataValue[TensorType],
attrs: Optional[attributes.AwesomeAttributes],
quant_attrs: Optional[attributes.AwesomeQuantAttrBase],
input_dict: Dict[InputName, Any],
data_layout: str,
name: str,
rounding_type: RoundType = RoundType.UPWARD) -> CompilerVertex:
if quant_attrs is None and output_type.value.scalar == ScalarType.bfloat16:
raise UserFacingException(
f"Layer {name} does not have bfloat16 support. Consider using int8 or int16 "
f"quantization mode."
)
assert isinstance(quant_attrs, attributes.UDFQuantAttrs)
return UdfN2ACompiler._gen_vertex(quant_attrs, input_dict, name)
class ErfN2ACompiler(AfeToN2AOperationConverter):
@staticmethod
def gen_vertex(output_type: DataValue[TensorType],
attrs: Optional[attributes.AwesomeAttributes],
quant_attrs: Optional[attributes.AwesomeQuantAttrBase],
input_dict: Dict[InputName, Any],
data_layout: str,
name: str,
rounding_type: RoundType = RoundType.UPWARD) -> CompilerVertex:
if quant_attrs is not None:
return UdfN2ACompiler.gen_vertex(output_type, attrs, quant_attrs, input_dict, data_layout, name,
rounding_type)
else:
assert attrs is not None
data_vertex = input_dict['data']
return n2a_compiler_utils.gen_compiler_erf_vertex(data_vertex, name)
class SliceN2ACompiler(AfeToN2AOperationConverter):
@staticmethod
def gen_vertex(output_type: DataValue[TensorType],
attrs: Optional[attributes.AwesomeAttributes],
quant_attrs: Optional[attributes.AwesomeQuantAttrBase],
input_dict: Dict[InputName, Any],
data_layout: str,
name: str,
rounding_type: RoundType = RoundType.UPWARD) -> CompilerVertex:
data_vertex = input_dict['data']
assert isinstance(attrs, attributes.StridedSliceAttrs)
begin, end, strides = operations.expand_indices_to_shape_length(
begin=attrs.begin, end=attrs.end, strides=attrs.strides,
axes=attrs.axes, input_shape=list(attrs.input_shape))
size = get_expected_tensor_value(output_type).shape
if len(size) == 4:
# NHWC -> 1HWC.
size = (1, *size[1:])
begin = (0, *begin[1:])
strides = (1, *strides[1:])
else:
# NDHWC -> DHWC.
size = size[1:]
begin = begin[1:]
strides = strides[1:]
return n2a_compiler_utils.gen_compiler_slice_vertex(
name, data_vertex, begin=begin, size=size, stride=strides
)
class SoftmaxN2ACompiler(AfeToN2AOperationConverter):
@staticmethod
def gen_vertex(output_type: DataValue[TensorType],
attrs: Optional[attributes.AwesomeAttributes],
quant_attrs: Optional[attributes.AwesomeQuantAttrBase],
input_dict: Dict[InputName, Any],
data_layout: str,
name: str,
rounding_type: RoundType = RoundType.UPWARD) -> CompilerVertex:
if quant_attrs is None:
assert isinstance(attrs, attributes.SoftmaxAttrs)
non_neg_axis = convert_axis_to_non_negative(attrs.axis, len(attrs.input_shape), False)
lut_exp_data = compute_bf16_exp_lut()
lut_rec_data = compute_bf16_reciprocal_lut()
exp_zp = rec_zp = lut_input_pre_shift = output_pre_shift = None
requant_lut = requant_output = get_id_requantization(bfloat16)
else:
assert isinstance(quant_attrs, attributes.SoftmaxQuantAttrs)
non_neg_axis = convert_axis_to_non_negative(quant_attrs.axis, len(quant_attrs.input_shape), False)
lut_exp_data = quant_attrs.lookup_table_exp
lut_rec_data = quant_attrs.lookup_table_rec
exp_zp = quant_attrs.exp_zp
rec_zp = quant_attrs.rec_zp
requant_lut = quant_attrs.requant_lut
requant_output = quant_attrs.requant_output
lut_input_pre_shift = quant_attrs.lut_input_pre_shift
output_pre_shift = quant_attrs.output_pre_shift
assert non_neg_axis == data_layout.index("C"), \
f"Operator's axis ({non_neg_axis}) != channel axis ({data_layout.index('C')})"
axis = n2a_compiler_utils.axis_to_compiler_data_layout(non_neg_axis, data_layout)
# Get input vertices
data_vertex = input_dict["data"]
lut_exp_vertex = n2a_compiler_utils.gen_compiler_constant_vertex(
f"{name}_exp_lut", lut_exp_data
)
lut_rec_vertex = n2a_compiler_utils.gen_compiler_constant_vertex(
f"{name}_rec_lut", lut_rec_data
)
return n2a_compiler_utils.gen_compiler_softmax_vertex(
name, data_vertex=data_vertex, lut_exp_vertex=lut_exp_vertex,
lut_rec_vertex=lut_rec_vertex, axis=axis, exp_zp=exp_zp, rec_zp=rec_zp,
requant_lut=requant_lut, requant_output=requant_output,
lut_input_pre_shift=lut_input_pre_shift, output_pre_shift=output_pre_shift
)
class LayerNormN2ACompiler(AfeToN2AOperationConverter):
@staticmethod
def gen_vertex(output_type: DataValue[TensorType],
attrs: Optional[attributes.AwesomeAttributes],
quant_attrs: Optional[attributes.AwesomeQuantAttrBase],
input_dict: Dict[InputName, Any],
data_layout: str,
name: str,
rounding_type: RoundType = RoundType.UPWARD) -> CompilerVertex:
if quant_attrs is None:
assert isinstance(attrs, attributes.LayerNormAttrs)
axis = attrs.axis
epsilon = attrs.epsilon
lut_vertex = n2a_compiler_utils.gen_compiler_constant_vertex(
f"{name}_rsqrt_lut", compute_bf16_rsqrt_lut()
)
input_shape = attrs.input_shape
zp_rsqrt = None
requant_mean = None
requant_lut_input = None
requant_output = None
else:
assert isinstance(quant_attrs, attributes.LayerNormQuantAttrs)
axis = quant_attrs.axis
epsilon = None
input_shape = quant_attrs.input_shape
lut_vertex = n2a_compiler_utils.gen_compiler_constant_vertex(
name + '_rsqrt_lut', quant_attrs.lookup_table_rsqrt.reshape(16, 16))
zp_rsqrt = quant_attrs.zp_rsqrt
requant_mean = quant_attrs.requant_mean
requant_lut_input = quant_attrs.requant_lut_input
requant_output = quant_attrs.requant_output
non_neg_axis = convert_axis_to_non_negative(axis, len(input_shape), False)
assert non_neg_axis == data_layout.index("C"), \
f"Operator's axis ({non_neg_axis}) != channel axis ({data_layout.index('C')})"
axis_compiler_data_layout = n2a_compiler_utils.axis_to_compiler_data_layout(non_neg_axis, data_layout)
# Get input vertices.
data_vertex = input_dict["data"]
return n2a_compiler_utils.gen_compiler_layer_norm_vertex(
name, data_vertex=data_vertex,
lut_vertex=lut_vertex,
axis=axis_compiler_data_layout, epsilon=epsilon,
rsqrt_zp=zp_rsqrt,
requant_mean=requant_mean,
requant_lut_input=requant_lut_input,
requant_output=requant_output)
class RMSNormN2ACompiler(AfeToN2AOperationConverter):
@staticmethod
def gen_vertex(output_type: DataValue[TensorType],
attrs: Optional[attributes.AwesomeAttributes],
quant_attrs: Optional[attributes.AwesomeQuantAttrBase],
input_dict: Dict[InputName, Any],
data_layout: str,
name: str,
rounding_type: RoundType = RoundType.UPWARD) -> CompilerVertex:
# Get input vertices.
data_vertex = input_dict["data"]
axis = n2a_compiler_utils.axis_to_compiler_data_layout(data_layout.index("C"), data_layout)
if quant_attrs is None:
assert isinstance(attrs, attributes.RMSNormAttrs)
lut_vertex = n2a_compiler_utils.gen_compiler_constant_vertex(
f"{name}_rsqrt_lut", compute_bf16_rsqrt_lut()
)
return gen_compiler_rms_norm_vertex(
name, data_vertex, lut_vertex, axis, epsilon=attrs.epsilon
)
else:
assert isinstance(quant_attrs, attributes.RMSNormQuantAttrs)
lut_vertex = n2a_compiler_utils.gen_compiler_constant_vertex(
f"{name}_rsqrt_lut", quant_attrs.lookup_table_rsqrt
)
return gen_compiler_rms_norm_vertex(
name, data_vertex, lut_vertex, axis, zp_ifm=quant_attrs.zp_ifm,
zp_rsqrt=quant_attrs.zp_rsqrt, requant_lut=quant_attrs.requant_lut_input,
requant_output=quant_attrs.requant_output,
lut_input_pre_shift=quant_attrs.lut_input_pre_shift,
output_pre_shift=quant_attrs.output_pre_shift, epsilon=None
)
class InstanceNormN2ACompiler(AfeToN2AOperationConverter):
@staticmethod
def gen_vertex(output_type: DataValue[TensorType],
attrs: Optional[attributes.AwesomeAttributes],
quant_attrs: Optional[attributes.AwesomeQuantAttrBase],
input_dict: Dict[InputName, Any],
data_layout: str,
name: str,
rounding_type: RoundType = RoundType.UPWARD) -> CompilerVertex:
# Get input vertices.
data_vertex = input_dict[InputName("data")]
mean_vertex = input_dict[InputName("mean")]
variance_vertex = input_dict[InputName("variance")]
is_quantized = quant_attrs is not None
if is_quantized:
assert isinstance(quant_attrs, attributes.InstanceNormQuantAttrs)
lut_vertex = n2a_compiler_utils.gen_compiler_constant_vertex(
f"{name}_rsqrt_lut", quant_attrs.lut_rsqrt
)
requant = quant_attrs.requant_out
zp_rsqrt = quant_attrs.zp_rsqrt
epsilon = None
else:
assert isinstance(attrs, attributes.InstanceNormAttrs)
lut_vertex = n2a_compiler_utils.gen_compiler_constant_vertex(
f"{name}_rsqrt_lut", compute_bf16_rsqrt_lut()
)
requant = None
zp_rsqrt = None
epsilon = attrs.epsilon
        return n2a_compiler_utils.gen_compiler_instance_norm_vertex(
            name, data_vertex, mean_vertex, variance_vertex, lut_vertex,
            zp_rsqrt, requant, epsilon
        )
class GridSampleN2ACompiler(AfeToN2AOperationConverter):
@staticmethod
def gen_vertex(output_type: DataValue[TensorType],
attrs: Optional[attributes.AwesomeAttributes],
quant_attrs: Optional[attributes.AwesomeQuantAttrBase],
input_dict: Dict[InputName, Any],
data_layout: str,
name: str,
rounding_type: RoundType = RoundType.UPWARD) -> CompilerVertex:
data_vertex = input_dict["data"]
grid_vertex = input_dict["grid"]
assert quant_attrs is None and isinstance(attrs, attributes.GridSampleAttrs), \
"Only bfloat16 is supported for GridSample"
padding_mode = attrs.padding_mode
align_corners = attrs.align_corners
return n2a_compiler_utils.gen_compiler_grid_sample_vertex(
name, data_vertex, grid_vertex, padding_mode, align_corners
)
class BatchMatmulN2ACompiler(AfeToN2AOperationConverter):
@staticmethod
def gen_vertex(output_type: DataValue[TensorType],
attrs: Optional[attributes.AwesomeAttributes],
quant_attrs: Optional[attributes.AwesomeQuantAttrBase],
input_dict: Dict[InputName, Any],
data_layout: str,
name: str,
rounding_type: RoundType = RoundType.UPWARD) -> CompilerVertex:
assert isinstance(quant_attrs, (attributes.BatchMatmulQuantAttrs, NoneType))
# Get input vertices
lhs_vertex = input_dict["lhs"]
rhs_vertex = input_dict["rhs"]
is_quantized = quant_attrs is not None
if is_quantized:
transpose_b = quant_attrs.attrs.transpose_b
input_zps = [quant_attrs.lhs_zp, quant_attrs.rhs_zp]
requant = quant_attrs.requant
intrinsic_shift = quant_attrs.intrinsic_shift
else:
assert isinstance(attrs, attributes.BatchMatmulAttrs)
transpose_b = attrs.transpose_b
input_zps = [0, 0]
requant = get_id_requantization("bfloat16")
intrinsic_shift = 0
return n2a_compiler_utils.gen_compiler_batch_matmul_vertex(
name, lhs=lhs_vertex, rhs=rhs_vertex, transpose_b=transpose_b,
input_zps=input_zps, requant=requant,
intrinsic_shift=intrinsic_shift
)
class UnaryBatchMatmulN2ACompiler(AfeToN2AOperationConverter):
@staticmethod
def gen_vertex(output_type: DataValue[TensorType],
attrs: Optional[attributes.AwesomeAttributes],
quant_attrs: Optional[attributes.AwesomeQuantAttrBase],
input_dict: Dict[InputName, Any],
data_layout: str,
name: str,
rounding_type: RoundType = RoundType.UPWARD) -> CompilerVertex:
assert isinstance(quant_attrs, attributes.BatchMatmulQuantAttrs)
# Get input vertices
data_vertex = input_dict["data"]
return n2a_compiler_utils.gen_compiler_batch_matmul_vertex(
name, lhs=data_vertex, rhs=data_vertex, transpose_b=quant_attrs.attrs.transpose_b,
input_zps=[quant_attrs.lhs_zp, quant_attrs.rhs_zp], requant=quant_attrs.requant,
intrinsic_shift=quant_attrs.intrinsic_shift
)
class SliceConcatN2ACompiler(AfeToN2AOperationConverter):
@staticmethod
def gen_vertex(output_type: DataValue[TensorType], attrs: attributes.AwesomeAttributes,
quant_attrs: attributes.AwesomeQuantAttrBase, input_dict: Dict[InputName, Any],
data_layout: str, name: str, rounding_type: RoundType = RoundType.UPWARD) \
-> CompilerVertex:
if quant_attrs is None:
assert isinstance(attrs, attributes.SliceConcatAttrs)
slice_attrs = attrs.slice_attrs
concat_attrs = attrs.tuple_concat_attrs.concat_attrs
requants = [get_id_requantization(bfloat16)]
else:
assert isinstance(quant_attrs, attributes.SliceConcatQuantAttrs)
slice_attrs = quant_attrs.slice_attrs
tuple_concat_attrs = quant_attrs.tuple_concat_attrs
concat_attrs = tuple_concat_attrs.attrs
requants = quant_attrs.tuple_concat_attrs.requants
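        # The fused slice-and-concatenate is emitted as a single concat vertex: its input
        # is split into len(slice_attrs) blocks along the slice axis and re-joined along
        # the concatenation axis.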
split_axis = n2a_compiler_utils.axis_to_compiler_data_layout(
slice_attrs[0].axes[0], AwesomeDataLayout
)
split_block = len(slice_attrs)
concat_axis = n2a_compiler_utils.axis_to_compiler_data_layout(
concat_attrs.axis, AwesomeDataLayout
)
input_vertex = input_dict["data"]
return n2a_compiler_utils.gen_compiler_concat_vertex(
name, [input_vertex], requants=requants, axis=concat_axis, split_axis=split_axis,
split_block=split_block
)
class BroadcastToN2ACompiler(AfeToN2AOperationConverter):
@staticmethod
def gen_vertex(output_type: DataValue[TensorType], attrs: attributes.AwesomeAttributes,
quant_attrs: attributes.AwesomeQuantAttrBase, input_dict: Dict[InputName, Any],
data_layout: str, name: str, rounding_type: RoundType = RoundType.UPWARD) \
-> CompilerVertex:
if quant_attrs is None:
output_shape = attrs.output_shape
else:
assert isinstance(quant_attrs, attributes.BroadcastToQuantAttrs)
output_shape = quant_attrs.output_shape
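        # The compiler expects a 4D shape: a 5D output shape drops its leading batch
        # dimension, and a 4D shape must already have batch size 1.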
if len(output_shape) == 5:
output_shape = output_shape[1:]
else:
assert len(output_shape) == 4 and output_shape[0] == 1
input_vertex = input_dict["data"]
return n2a_compiler_utils.gen_compiler_broadcast_to_vertex(name, input_vertex, output_shape)
class RequantizeN2ACompiler(AfeToN2AOperationConverter):
@staticmethod
def gen_vertex(output_type: DataValue[TensorType], attrs: attributes.AwesomeAttributes,
quant_attrs: attributes.AwesomeQuantAttrBase, input_dict: Dict[InputName, Any],
data_layout: str, name: str, rounding_type: RoundType = RoundType.UPWARD) \
-> CompilerVertex:
assert isinstance(quant_attrs, attributes.RequantizeQuantAttrs)
requant = quant_attrs.requant
input_vertex = input_dict["data"]
return n2a_compiler_utils.gen_compiler_requantization_vertex(name, input_vertex, requant)
class DequantizationN2ACompiler(AfeToN2AOperationConverter):
@staticmethod
def gen_vertex(output_type: DataValue[TensorType], attrs: attributes.AwesomeAttributes,
quant_attrs: attributes.AwesomeQuantAttrBase, input_dict: Dict[InputName, Any],
data_layout: str, name: str, rounding_type: RoundType = RoundType.UPWARD) \
-> CompilerVertex:
assert isinstance(attrs, attributes.DequantizationTransformAttrs)
assert len(attrs.channel_params) == 1
scale, zp = attrs.channel_params[0]
requant = FloatRequantization(
sc_correction=np.float32(scale), zp_correction=zp,
out_dtype=attrs.output_type.numpy_type()
)
input_vertex = input_dict["data"]
return n2a_compiler_utils.gen_compiler_requantization_vertex(name, input_vertex, requant)
class QuantizationN2ACompiler(AfeToN2AOperationConverter):
@staticmethod
def gen_vertex(output_type: DataValue[TensorType], attrs: attributes.AwesomeAttributes,
quant_attrs: attributes.AwesomeQuantAttrBase, input_dict: Dict[InputName, Any],
data_layout: str, name: str, rounding_type: RoundType = RoundType.UPWARD) \
-> CompilerVertex:
assert len(attrs.channel_params) == 1
scale, zp = attrs.channel_params[0]
requant = FloatRequantization(
sc_correction=np.float32(scale), zp_correction=zp,
out_dtype=attrs.output_data_type.numpy_type()
)
input_vertex = input_dict["data"]
return n2a_compiler_utils.gen_compiler_requantization_vertex(name, input_vertex, requant)
class TransposeN2ACompiler(AfeToN2AOperationConverter):
@staticmethod
def gen_vertex(output_type: DataValue[TensorType], attrs: attributes.AwesomeAttributes,
quant_attrs: attributes.AwesomeQuantAttrBase, input_dict: Dict[InputName, Any],
data_layout: str, name: str, rounding_type: RoundType = RoundType.UPWARD) \
-> CompilerVertex:
assert isinstance(attrs, attributes.TransposeAttrs)
        # Adjust the permutation for the compiler layout: a 4D permutation keeps the
        # batch axis pinned at position 0, while a 5D permutation drops the batch axis.
if len(attrs.axes) == 4:
new_axes = (0, *attrs.axes[1:])
else:
assert len(attrs.axes) == 5
new_axes = tuple([axis - 1 for axis in attrs.axes[1:]])
input_vertex = input_dict["data"]
return n2a_compiler_utils.gen_compiler_transpose_vertex(name, input_vertex, perm=new_axes)
class DepthToSpaceN2ACompiler(AfeToN2AOperationConverter):
@staticmethod
def gen_vertex(
output_type: DataValue[TensorType], attrs: attributes.AwesomeAttributes,
quant_attrs: attributes.AwesomeQuantAttrBase, input_dict: Dict[InputName, Any],
data_layout: str, name: str, rounding_type: RoundType = RoundType.UPWARD
) -> CompilerVertex:
assert isinstance(attrs, attributes.DepthToSpaceAttrs)
input_vertex = input_dict["data"]
        return n2a_compiler_utils.gen_compiler_depth_to_space_vertex(
            name, input_vertex, block_size=attrs.block_size, mode=attrs.mode
        )
class DivideN2ACompiler(AfeToN2AOperationConverter):
@staticmethod
def gen_vertex(output_type: DataValue[TensorType], attrs: attributes.AwesomeAttributes,
quant_attrs: attributes.AwesomeQuantAttrBase, input_dict: Dict[InputName, Any],
data_layout: str, name: str, rounding_type: RoundType = RoundType.UPWARD) \
-> CompilerVertex:
assert isinstance(quant_attrs, attributes.DivideQuantAttrs)
lhs_vertex = input_dict["lhs"]
rhs_vertex = input_dict["rhs"]
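        # Divide is lowered as lhs * reciprocal(rhs): the reciprocal is produced by the
        # UDF lookup-table path and the result is fed to the multiply converter.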
reciprocal_vertex = UdfN2ACompiler.gen_vertex(
output_type, None, quant_attrs.udf_attrs, {'data': rhs_vertex}, data_layout,
f'{name}_rhs_reciprocal', rounding_type
)
return MultiplyN2ACompiler.gen_vertex(
output_type, None, quant_attrs.multiply_attrs, {'lhs': lhs_vertex, 'rhs': reciprocal_vertex},
data_layout, name, rounding_type
)
class SigmoidN2ACompiler(AfeToN2AOperationConverter):
@staticmethod
def gen_vertex(
output_type: DataValue[TensorType], attrs: attributes.AwesomeAttributes,
quant_attrs: attributes.AwesomeQuantAttrBase | None, input_dict: Dict[InputName, Any],
data_layout: str, name: str, rounding_type: RoundType = RoundType.UPWARD
) -> CompilerVertex:
if isinstance(quant_attrs, attributes.UDFQuantAttrs):
return UdfN2ACompiler.gen_vertex(
output_type, attrs, quant_attrs, input_dict, data_layout, name, rounding_type
)
else:
assert quant_attrs is None and isinstance(attrs, attributes.UDFAttrs)
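            # In bfloat16 mode, sigmoid(x) = 1 / (1 + exp(-x)) is evaluated with an exp
            # lookup table followed by a reciprocal lookup table.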
lut_exp = n2a_compiler_utils.gen_compiler_constant_vertex(
f"{name}_lut_exp", compute_bf16_exp_lut()
)
lut_rec = n2a_compiler_utils.gen_compiler_constant_vertex(
f"{name}_lut_rec", compute_bf16_reciprocal_lut()
)
return n2a_compiler_utils.gen_compiler_sigmoid_vertex(
name, input_dict["data"], lut_exp, lut_rec
)
class SwishN2ACompiler(AfeToN2AOperationConverter):
@staticmethod
def gen_vertex(
output_type: DataValue[TensorType], attrs: attributes.AwesomeAttributes,
quant_attrs: attributes.AwesomeQuantAttrBase | None, input_dict: Dict[InputName, Any],
data_layout: str, name: str, rounding_type: RoundType = RoundType.UPWARD
) -> CompilerVertex:
if isinstance(quant_attrs, attributes.UDFQuantAttrs):
return UdfN2ACompiler.gen_vertex(
output_type, attrs, quant_attrs, input_dict, data_layout, name, rounding_type
)
else:
assert quant_attrs is None and isinstance(attrs, attributes.UDFAttrs)
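            # In bfloat16 mode, swish(x) = x * sigmoid(x) is assembled from the sigmoid
            # converter followed by an elementwise multiply.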
sigmoid_vertex = SigmoidN2ACompiler.gen_vertex(
output_type, attrs, None, input_dict, data_layout, f"{name}/sigmoid",
rounding_type
)
multiply_attrs = attributes.MultiplyAttrs(
attrs.scalar_type, attrs.input_shape, attrs.input_shape
)
multiply_input_dict = {"lhs": input_dict["data"], "rhs": sigmoid_vertex}
return MultiplyN2ACompiler.gen_vertex(
output_type, multiply_attrs, None, multiply_input_dict, data_layout,
f"{name}/multiply", rounding_type
)
class TanhN2ACompiler(AfeToN2AOperationConverter):
@staticmethod
def gen_vertex(
output_type: DataValue[TensorType], attrs: attributes.AwesomeAttributes,
quant_attrs: attributes.AwesomeQuantAttrBase | None, input_dict: Dict[InputName, Any],
data_layout: str, name: str, rounding_type: RoundType = RoundType.UPWARD
) -> CompilerVertex:
if isinstance(quant_attrs, attributes.UDFQuantAttrs):
return UdfN2ACompiler.gen_vertex(
output_type, attrs, quant_attrs, input_dict, data_layout, name, rounding_type
)
else:
            # For bfloat16, rewrite tanh(x) as 2*sigmoid(2x) - 1.
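            # Identity: 2*sigmoid(2x) - 1 = (1 - e^(-2x)) / (1 + e^(-2x)) = tanh(x),
            # so tanh can reuse the bfloat16 sigmoid LUT path below.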
assert quant_attrs is None and isinstance(attrs, attributes.UDFAttrs)
constant_vertex_1 = n2a_compiler_utils.gen_compiler_constant_vertex(
f"{name}/const_1",
np.array([2]).astype(bfloat16)
)
multiply_attrs_1 = attributes.MultiplyAttrs(
attrs.scalar_type, attrs.input_shape, (1,)
)
multiply_input_dict_1 = {"lhs": input_dict["data"], "rhs": constant_vertex_1}
multiply_vertex_1 = MultiplyN2ACompiler.gen_vertex(
output_type, multiply_attrs_1, None, multiply_input_dict_1, data_layout,
f"{name}/multiply_1", rounding_type
)
sigmoid_vertex = SigmoidN2ACompiler.gen_vertex(
output_type, attrs, None, {"data": multiply_vertex_1},
data_layout, f"{name}/sigmoid",
rounding_type
)
constant_vertex_2 = n2a_compiler_utils.gen_compiler_constant_vertex(
f"{name}/const_2",
np.array([2]).astype(bfloat16)
)
multiply_attrs_2 = attributes.MultiplyAttrs(
attrs.scalar_type, attrs.input_shape, (1,)
)
multiply_input_dict_2 = {"lhs": sigmoid_vertex, "rhs": constant_vertex_2}
multiply_vertex_2 = MultiplyN2ACompiler.gen_vertex(
output_type, multiply_attrs_2, None, multiply_input_dict_2, data_layout,
f"{name}/multiply_2", rounding_type
)
constant_vertex_3 = n2a_compiler_utils.gen_compiler_constant_vertex(
f"{name}/const_3",
np.array([1]).astype(bfloat16)
)
subtract_attrs = attributes.SubtractAttrs(
attrs.scalar_type, attrs.input_shape, (1,)
)
subtract_input_dict = {"lhs": multiply_vertex_2, "rhs": constant_vertex_3}
return SubtractN2ACompiler.gen_vertex(
output_type, subtract_attrs, None, subtract_input_dict, data_layout,
f"{name}/subtract", rounding_type
)
# Dictionary used in converting Awesome Operations to Compiler Vertices
AFE_TO_N2ACOMPILER_OPERATOR_DICT: Dict[Type[operations.AwesomeOperation],
Type[AfeToN2AOperationConverter]] = {
operations.PlaceholderOp: PlaceholderN2ACompiler,
operations.ConstantOp: ConstantN2ACompiler,
operations.TupleOp: TupleN2A,
operations.ConcatenateOp: ConcatenateN2A,
operations.MaxPool2DOp: MaxPoolN2ACompiler,
operations.AvgPool2DOp: AvgPoolN2ACompiler,
operations.AvgPool3DOp: AvgPoolN2ACompiler,
operations.MeanOp: MeanN2ACompiler,
operations.LRNOp: LRNN2ACompiler,
operations.LeakyReluCompositeOp: LeakyReluN2ACompiler,
operations.UpsamplingOp: UpsamplingN2ACompiler,
operations.ImageResize2DOp: ImageResize2DN2ACompiler,
operations.MultiplyOp: MultiplyN2ACompiler,
operations.SqrtOp: UdfN2ACompiler,
operations.RsqrtOp: UdfN2ACompiler,
operations.TanhOp: TanhN2ACompiler,
operations.SigmoidOp: SigmoidN2ACompiler,
operations.LogOp: UdfN2ACompiler,
operations.Log2Op: UdfN2ACompiler,
operations.Log10Op: UdfN2ACompiler,
operations.ExpOp: UdfN2ACompiler,
operations.SwishOp: SwishN2ACompiler,
operations.EluOp: UdfN2ACompiler,
operations.PReluOp: PreluN2A,
operations.ReluOp: ReluN2ACompiler,
operations.SoftplusOp: UdfN2ACompiler,
operations.ErfOp: ErfN2ACompiler,
operations.HardSwishOp: UdfN2ACompiler,
operations.HardSigmoidOp: UdfN2ACompiler,
operations.SubtractOp: SubtractN2ACompiler,
operations.ArgMaxOp: ArgMaxN2ACompiler,
operations.SoftmaxOp: SoftmaxN2ACompiler,
operations.RequantizeOp: RequantizeN2ACompiler,
operations.TransposeOp: TransposeN2ACompiler,
operations.DepthToSpaceOp: DepthToSpaceN2ACompiler,
operations.DequantizationTransformOp: DequantizationN2ACompiler,
operations.QuantizationTransformOp: QuantizationN2ACompiler,
operations.VarianceOp: VarianceN2ACompiler,
operations.ClipOp: ClipN2ACompiler,
# COMPOSITE OPERATIONS
operations.ConvAddActivationOp: ConvN2ACompiler,
operations.AddActivationOp: AddActivationN2ACompiler,
operations.TupleConcatenateOp: TupleConcatenateN2A,
operations.ConstantMultiplyAddOp: ConstantMultiplyAddN2ACompiler,
operations.StridedSliceOp: SliceN2ACompiler,
operations.LayerNormOp: LayerNormN2ACompiler,
operations.BatchMatmulOp: BatchMatmulN2ACompiler,
operations.UnaryBatchMatmulOp: UnaryBatchMatmulN2ACompiler,
operations.SliceConcatOp: SliceConcatN2ACompiler,
operations.BroadcastToOp: BroadcastToN2ACompiler,
operations.DivideOp: DivideN2ACompiler,
operations.RMSNormOp: RMSNormN2ACompiler,
operations.GeluOp: UdfN2ACompiler,
operations.GridSampleOp: GridSampleN2ACompiler,
operations.InstanceNormOp: InstanceNormN2ACompiler
}
def get_afe_to_n2a_converter(node: AwesomeNode) -> Type[AfeToN2AOperationConverter]:
"""
Utility function for getting the AFE node to CompilerVertex converter from the converters' dictionary.
:param node: AwesomeNode that is to be converted to CompilerVertex.
:return: The converter type from the AFE_TO_N2ACOMPILER_OPERATOR_DICT.
"""
_converters: Dict[Type[operations.AwesomeOperation], Type[AfeToN2AOperationConverter]] = \
AFE_TO_N2ACOMPILER_OPERATOR_DICT
# Get converter for the SimaIR operator
    if type(node.ir.operation) not in _converters:
raise TypeError(f"n2a compiler does not support operation {node.ir.operation.__class__.__name__} "
f"in {node.name} node!")
return _converters[type(node.ir.operation)]
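
# Illustrative usage sketch (not part of the production flow; `node`, `output_type`,
# `float_attrs`, `quant_attrs_or_none`, and `input_vertices` are hypothetical placeholders):
#
#     converter = get_afe_to_n2a_converter(node)
#     vertex = converter.gen_vertex(
#         output_type=output_type,            # DataValue[TensorType] of the node's output
#         attrs=float_attrs,                  # AwesomeAttributes, or None when quantized
#         quant_attrs=quant_attrs_or_none,    # AwesomeQuantAttrBase, or None in floating point
#         input_dict=input_vertices,          # InputName -> already-converted CompilerVertex
#         data_layout=AwesomeDataLayout,
#         name=node.name,
#         rounding_type=RoundType.UPWARD,
#     )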