Source code for afe.backends.mla.afe_to_n2a_compiler.n2a_compiler_operations

#########################################################
# Copyright (C) 2020 SiMa Technologies, Inc.
#
# This material is SiMa proprietary and confidential.
#
# This material may not be copied or distributed without
# the express prior written permission of SiMa.
#
# All rights reserved.
#########################################################
# Code owner: Joey Chou
#########################################################
from abc import abstractmethod
from collections import OrderedDict
import copy
from dataclasses import dataclass
from math import prod
from pathlib import Path

import numpy as np
from typing import Dict, Any, Optional, Type, List, Union, Set
from typing_extensions import assert_never

import afe.ir.quantization_utils
from afe._tvm._utils import convert_axis_to_non_negative
from afe.backends.mla.afe_to_n2a_compiler.n2a_backend_runner import N2ABackendSimulator, N2ACompiledBackendRunner, EvaluateTaskType
from afe.backends.mla.afe_to_n2a_compiler.n2a_compiler_utils import make_mlc_file_name, gen_compiler_rms_norm_vertex

from afe.backends import Backend, BackendIR
from afe.backends.mla.afe_to_n2a_compiler import n2a_compiler_utils
from afe.backends.mla.afe_to_n2a_compiler.defines import (
    Activation, ArithFoldedRequantization, CompilerVertex, QuantizedWeightDtypes,
    FractionalZeroRequantization, L2CachingMode, ModelGraph, RoundType, TFLiteRequantization,
    TessellateParameters, TuplePlaceholderOperator, bfloat16, compute_bf16_exp_lut,
    compute_bf16_reciprocal_lut, compute_bf16_rsqrt_lut, create_random_tensor, export_model_to_json,
    generate_l1_based_model, get_id_requantization, get_max_batch_size_and_quadrants_used,
    import_model_from_json, set_model_compilation_properties
)
from afe.backends.mla.afe_to_n2a_compiler.insert_nodes import (
    insert_pre_mla_segment_nodes, insert_post_mla_segment_nodes
)
import afe.ir.attributes as attributes
from afe.ir.defines import (
    InputName, AwesomeDataLayout, AwesomeDataLayout5D, NoneType, NodeName, Status, DataValue,
    get_expected_tensor_value
)
from afe.ir.net import AwesomeNet, is_one_mla_segment_net, update_awesomenet_status
from afe.ir.node import AwesomeNode, node_is_awesomenet
import afe.ir.operations as operations
import afe.ir.operation_functions as op_fn
from afe.ir.quantization_utils import quantization_data_value_to_output_list
from afe.ir.tensor_type import TensorType, ScalarType
import afe.ir.utils as utils
from mlc.compiler.model_graph import PlaceholderName, PlaceholderValues
from mlc.test_util.test_context import CompilerConfig
from sima_utils.common import Platform
from sima_utils.logging.sima_logger import sima_log_dbg, sima_log_error, sima_log_info, UserFacingException

from ml_kernels.requantization import FloatRequantization, Renormalization


@dataclass
class MLACompilerConfig:
    """
    Parameters that control how to run the Production Compiler for a model.
    Parameters for a specific subgraph or MLC file do not belong here.
    They should be passed when calling the production compiler, instead.

    :param tessellate_parameters: Dictionary containing information on tessellation
        parameters for inputs and outputs of the MLA subgraph. For more information on
        how parameters are interpreted, please see insert_nodes.py.
    :param verify_mlc_files: Whether to verify correctness of the generated mlc files.
        Default is True. It should be disabled only in test cases which also execute
        the MLA nodes using N2ABackendSimulator.
    :param enable_large_tensors: If true, the MLA will handle large tensors, otherwise
        large tensors will raise an exception.
    :param l2_caching_mode: Parameters controlling L2 caching in compiler.
    :param platform_type: Target MLA architecture type.
    :param use_power_limits: If true, the compiler will schedule instructions to
        conform to power limits.
    :param max_power: Set to a positive float value to override default max power
        when power limits are used.
    :param compress: If true, the compiler will compress the data.
    :param layer_norm_use_fp32_intermediates: Use FP32 intermediate tensors in BF16
        LayerNorm kernel.
    :param rms_norm_use_fp32_intermediates: Use FP32 intermediate tensors in BF16
        RMSNorm kernel.
    """
    tessellate_parameters: Optional[TessellateParameters] = None
    verify_mlc_files: bool = True
    enable_large_tensors: bool = True
    l2_caching_mode: L2CachingMode = L2CachingMode.NONE
    platform_type: Platform = Platform.GEN1
    use_power_limits: bool = False
    max_power: float | None = None
    compress: bool = True
    layer_norm_use_fp32_intermediates: bool = False
    rms_norm_use_fp32_intermediates: bool = False

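
# Illustrative usage (a minimal sketch; `quantized_net` and `out_dir` are hypothetical
# names, not part of this module):
#
#     mla_config = MLACompilerConfig(platform_type=Platform.GEN1,
#                                    use_power_limits=True, max_power=12.5)
#     compile_mla_code(quantized_net, out_dir, mla_config, desired_batch_size=4)
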
def _gen_modelgraph_vertex(node: AwesomeNode, vertices: Dict[str, CompilerVertex],
                           data_layout: str,
                           rounding_type: str = RoundType.UPWARD) -> CompilerVertex:
    """
    Generate the CompilerVertex that is associated with the specified AwesomeNode.

    :param node: AwesomeNode that will be converted.
    :param vertices: Vertices already generated for previously converted nodes,
        keyed by node name.
    :param data_layout: Data layout of the input.
    :param rounding_type: Method of rounding.
    :return: CompilerVertex
    """
    # PlaceholderOp does not have inputs.
    if isinstance(node.ir.operation, operations.PlaceholderOp):
        input_vertices = {}
    else:
        input_vertices = OrderedDict({input_name: vertices[input_node_name]
                                      for input_name, input_node_name
                                      in zip(node.input_names, node.input_node_names)})
    afe_to_n2a_converter = get_afe_to_n2a_converter(node)
    return afe_to_n2a_converter.gen_vertex(node.get_type().output, node.ir.attrs,
                                           node.ir.quant_attrs, input_vertices,
                                           data_layout, node.name, rounding_type)

def create_modelgraph(net: AwesomeNet,
                      rounding_type: RoundType = RoundType.UPWARD) -> ModelGraph:
    """
    Convert a quantized AwesomeNet to a ModelGraph.

    :param net: A quantized AwesomeNet.
    :param rounding_type: Rounding method used in the model graph.
    :return: A ModelGraph
    """
    # Determine model layout based on input dimension
    input_node_names = list(net.input_node_names)
    ndim = len(net.nodes[input_node_names[0]].ir.get_attrs().type.shape)
    assert all(len(net.nodes[name].ir.get_attrs().type.shape) == ndim
               for name in input_node_names), \
        "Expected all inputs to be uniformly 4D or 5D for ModelGraph creation"
    data_layout = AwesomeDataLayout if ndim == 4 else AwesomeDataLayout5D

    vertices = OrderedDict()
    for node_name in net._execution_order:
        node = net.nodes[node_name]
        # TODO: Remove data_layout from compiler interface once the ConvertLayout
        #  is fully implemented
        vertices[node_name] = _gen_modelgraph_vertex(node, vertices,
                                                     data_layout=data_layout,
                                                     rounding_type=rounding_type)

    # Collect the input vertices for the ModelGraph.
    input_vertices = [vertices[node_name] for node_name in net.input_node_names]

    # Create output vertices for the ModelGraph. If the last vertex of the net is a
    # Tuple vertex, use its inputs to create the output vertices; otherwise, use the
    # last vertex to create the output vertex.
    output_vertex = list(vertices.values())[-1]
    if isinstance(output_vertex._operator, TuplePlaceholderOperator):
        output_vertices = [
            n2a_compiler_utils.gen_compiler_output_vertex(v) for v in output_vertex._inputs
        ]
    else:
        output_vertices = [n2a_compiler_utils.gen_compiler_output_vertex(output_vertex)]
    return ModelGraph(input_vertices, output_vertices)

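
# Sketch of typical use (assuming `quantized_net` is a fully quantized AwesomeNet whose
# operators are all MLA-supported; the name is hypothetical):
#
#     model_graph = create_modelgraph(quantized_net, rounding_type=RoundType.UPWARD)
#     # The resulting ModelGraph can then be configured and compiled, as done in
#     # compile_mla_code below.
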
def translate_sub_awesome_net_to_modelgraph(net: AwesomeNet,
                                            force_update_status: bool = False) \
        -> Union[AwesomeNet, ModelGraph]:
    """
    Take in an AwesomeNet and translate the MLA-supported sub-graphs to ModelGraphs.

    :param net: AwesomeNet
    :param force_update_status: If True, the net's status is updated even if status
        checks would otherwise reject the update.
    :return: A ModelGraph if the entire AwesomeNet is a single MLA segment; otherwise,
        an AwesomeNet in which every MLA sub-graph is replaced by a ModelGraph.
    """
    # Convert the entire AwesomeNet to a ModelGraph if it consists of a single
    # sub-graph that can be translated to MLA backend IR.
    if is_one_mla_segment_net(net):
        return create_modelgraph(net)

    for node_name in net.execution_order:
        node = net.nodes[node_name]
        if node_is_awesomenet(node):
            # This function only supports MLA backend
            assert node.ir.backend == Backend.MLA

            # Use sub-graph's placeholders' input quantization parameters as
            # the input quantization parameters of the backend graph
            input_layer_bits = []
            input_zps = []
            input_scales = []
            for name in node.ir.input_node_names:
                input_quantization = \
                    attributes.get_data_value_quant_result_scale_with_dummy(
                        node.ir.nodes[name].ir.calib_attrs.quant)
                scale, zero_point, layer_bits, _, _ = \
                    quantization_data_value_to_output_list(input_quantization)
                input_layer_bits.append(layer_bits)
                input_zps.append(zero_point)
                input_scales.append(scale)

            net_type = node.get_type()
            node.ir = BackendIR(create_modelgraph(node.ir), net_type, Backend.MLA)

    update_awesomenet_status(net, Status.BACKEND_IR_LOWERED, force_update_status)
    return net

def create_backend_runner(net: AwesomeNet, mla_config: MLACompilerConfig, output_path: str,
                          batch_size: int, stage: int) -> N2ABackendSimulator:
    """
    Create an N2ABackendSimulator that is used to verify mlc files.

    :param net: AwesomeNet.
    :param mla_config: Parameters controlling how to invoke the Production Compiler.
    :param output_path: Output path for .mlc files.
    :param batch_size: Batch size for which the model has been compiled.
    :param stage: Stage number of the graph.
    :return: The constructed N2ABackendSimulator.
    """
    def _report_sim_failure(info: str) -> None:
        raise RuntimeError(f"Simulation has detected a failure in {info}")

    return N2ABackendSimulator(out_dir=output_path, layout='NHWC', model_name=net.name,
                               batch_size=batch_size,
                               platform_type=mla_config.platform_type,
                               report_sim_failure=_report_sim_failure,
                               evaluate_task_type=EvaluateTaskType.EXT_CMD)

def _create_ofm_chk_mlc(
        backend_runner: N2ABackendSimulator, node: AwesomeNode,
        placeholder_values: PlaceholderValues,
        tessellated_placeholder_values: dict[InputName, np.ndarray], stage: int
):
    backend_runner.evaluate_model_graph(
        node.ir.graph, tessellated_placeholder_values,
        untessellated_inputs=placeholder_values, stage=stage
    )

def get_awesomenet_max_batch_size(net: AwesomeNet, desired_batch_size: int) -> Optional[int]:
    """
    Traverses the nodes of an AwesomeNet and finds the maximal batch size supported by
    the compiler for all MLA subnets, taken as the minimum of the maximal batch sizes
    supported for each individual MLA subnet. If there are no MLA subnets in the
    AwesomeNet, None is returned.

    :param net: An AwesomeNet whose maximal supported batch size is to be determined.
    :param desired_batch_size: The AwesomeNet inputs' desired batch size to be used in
        compilation. Compilation will choose the largest size that is supported for the
        entire AwesomeNet and that is no larger than this.
    :return: An integer value representing the maximal batch size supported by the
        compiler for the AwesomeNet, not larger than the desired batch size, or None
        if no MLA segments are present in the AwesomeNet.
    """
    nodes_max_batch_size: Set[int] = set()
    for node_name in net.execution_order:
        node = net.nodes[node_name]
        if isinstance(node.ir, BackendIR):
            max_batch_size, _ = get_max_batch_size_and_quadrants_used(
                node.ir.graph, desired_batch_size=desired_batch_size)
            nodes_max_batch_size.add(max_batch_size)
    return min(nodes_max_batch_size) if nodes_max_batch_size else None

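
# Worked example: if a net has two MLA subnets whose compiler-supported maxima for a
# desired batch size of 8 are 8 and 4, the collected set is {8, 4} and this function
# returns 4; compile_mla_code then compiles every MLA segment with batch size 4.
#
#     max_bs = get_awesomenet_max_batch_size(net, desired_batch_size=8)
#     batch = 8 if max_bs is None else max_bs   # mirrors compile_mla_code's choice
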
def _get_placeholder_values(
        net: AwesomeNet, node: AwesomeNode, batch_size: int
) -> PlaceholderValues:
    """
    Generate placeholder values that will be used for mlc file verification.

    :param net: AwesomeNet
    :param node: AwesomeNode
    :param batch_size: Batch size for which random input tensors are created.
    :return: PlaceholderValues.
    """
    input_types = [
        net.nodes[input_node_name].get_type().output
        for input_node_name in node.input_node_names
    ]
    input_data = list()
    rng = np.random.default_rng(2)
    for input_type in input_types:
        in_type = get_expected_tensor_value(input_type)
        shape = in_type.shape
        # FIXME: Force to use bfloat16 if float32 is provided.
        dtype = in_type.scalar.numpy_type()
        if dtype == np.float32:
            dtype = bfloat16
        data = create_random_tensor((batch_size, *shape[1:]), dtype, rng=rng)
        input_data.append(data)
    placeholder_values = {
        PlaceholderName(input_name): in_data
        for input_name, in_data in zip(node.input_names, input_data)
    }
    return placeholder_values


def _tessellate_placeholder_values(
        net: AwesomeNet, node: AwesomeNode, placeholder_values: PlaceholderValues
) -> dict[InputName, np.ndarray]:
    pl_values = list(placeholder_values.values())
    tessellated_placeholder_values = dict()
    for input_node_name, pack_input_list in zip(
            node.input_node_names,
            node.ir.graph.compile_properties.pack_parameters.values()
    ):
        _node = net.nodes[input_node_name]
        if isinstance(_node.ir.attrs, attributes.PackTransformAttrs):
            pack_inputs = list()
            for node_name, (input_id, tessellate_param) in (
                    zip(_node.input_node_names, pack_input_list)
            ):
                pl_value = pl_values[input_id]
                if tessellate_param.enable_mla:
                    pack_input = pl_value
                else:
                    tessellate_node = net.nodes[node_name]
                    attrs = tessellate_node.ir.attrs
                    assert isinstance(attrs, attributes.TessellationTransformAttrs), (
                        f"Expected TessellationTransformAttrs, got {attrs}"
                    )
                    pack_input = op_fn.tessellation(attrs, pl_value)
                pack_inputs.append(pack_input)
            input_data = op_fn.pack(pack_inputs)
        else:
            assert len(pack_input_list) == 1
            input_id = pack_input_list[0][0]
            if isinstance(_node.ir.attrs, attributes.TessellationTransformAttrs):
                input_data = op_fn.tessellation(_node.ir.attrs, pl_values[input_id])
            else:
                input_data = op_fn.reshape_to_mla_padded_2d_shape(pl_values[input_id])
        tessellated_placeholder_values[InputName(input_node_name)] = input_data
    return tessellated_placeholder_values

def compile_mla_code(net: AwesomeNet, output_path: str, mla_config: MLACompilerConfig,
                     *, desired_batch_size: int = 1) -> int:
    """
    Take in an AwesomeNet and compile the MLA supported sub-graphs.

    Add nodes that are needed in order to execute the model on MLA.
      - Add Tessellate node for each input node to the MLA sub-graph.
      - Add Pack nodes if there are multiple inputs to the MLA sub-graph.
      - Add Unpack nodes if there are multiple outputs from the MLA sub-graph.
      - Add Detessellate node for each output node from the MLA sub-graph.

    :param net: AwesomeNet to be compiled.
    :param output_path: str. The path to where the generated .mlc files are written.
    :param mla_config: Parameters controlling how to invoke the Production Compiler
        for all MLA model graphs.
    :param desired_batch_size: The AwesomeNet inputs' desired batch size to be used in
        compilation. Compilation will choose the largest size that is supported for
        the entire AwesomeNet and that is no larger than this. The chosen batch size
        will be returned.
    :return: An integer value representing the batch size used by the production
        compiler. Also, writes the generated .mlc files to the output_path.
    """
    uncompiled_nodes: List[NodeName] = copy.deepcopy(net.execution_order)

    # Determine the maximal supported batch size for all MLA subnets in AwesomeNet.
    max_supported_batch_size = get_awesomenet_max_batch_size(net, desired_batch_size)
    compiler_batch_size = (desired_batch_size if max_supported_batch_size is None
                           else max_supported_batch_size)

    stage = 1
    while uncompiled_nodes:
        node_name = uncompiled_nodes.pop(0)
        node = net.nodes[node_name]
        if node.ir.backend == Backend.APU:
            stage += 1
        elif isinstance(node.ir, BackendIR) and node.ir.backend == Backend.MLA:
            compiler_config = CompilerConfig(
                make_mlc_file_name(output_path, net.name, stage),
                mla_config.platform_type, False, mla_config.use_power_limits,
                mla_config.max_power, mla_config.compress,
                mla_config.layer_norm_use_fp32_intermediates,
                mla_config.rms_norm_use_fp32_intermediates)

            enable_l1_check = False
            save_model_graph = False
            use_saved_model_graph = False
            json_dir = Path(output_path) / "model_graph_json"
            json_name = json_dir / f"{net.name}_stage{stage}.json"
            if use_saved_model_graph:
                model_graph, _ = import_model_from_json(json_name)
                node.ir.graph = model_graph
            else:
                model_graph = node.ir.graph

            set_model_compilation_properties(
                compiler_config, model_graph, batch_size=compiler_batch_size,
                tessellate_parameters=mla_config.tessellate_parameters,
                enable_large_tensors=mla_config.enable_large_tensors, validate=False
            )

            if save_model_graph:
                json_dir.mkdir(parents=True, exist_ok=True)
                export_model_to_json(model_graph, json_dir, json_name, dict())

            # TODO: Add checks if the tessellation parameters are valid, due to EV
            #  plugin limitations, number of inputs and input shapes.

            # Get placeholder values before adding pre-mla nodes.
            if mla_config.verify_mlc_files:
                placeholder_values = _get_placeholder_values(net, node, compiler_batch_size)

            insert_pre_mla_segment_nodes(net, node,
                                         model_graph.compile_properties.pack_parameters)
            insert_post_mla_segment_nodes(
                net, node, model_graph.compile_properties.unpack_parameters,
                uncompiled_nodes
            )

            if mla_config.verify_mlc_files:
                backend_runner = create_backend_runner(
                    net, mla_config, output_path, compiler_batch_size, stage
                )
                tessellated_placeholder_values = _tessellate_placeholder_values(
                    net, node, placeholder_values
                )
                if enable_l1_check:
                    from mlc.compiler.model_graph.evaluate import evaluate_model
                    from mlc.compiler.model_graph.l1_based import set_ofm_refs
                    enable_layer_ofm_chk_dict = {
                        ".*": True,
                    }
                    sima_log_dbg("Evaluating model graph to insert l1 checks")
                    ofm_ref = set_ofm_refs(
                        model_graph, placeholder_values, False,
                        enable_layer_ofm_chk_dict=enable_layer_ofm_chk_dict
                    )
                    N2ACompiledBackendRunner.write_inputs_to_ifm_file(
                        tessellated_placeholder_values,
                        model_graph.compile_properties.pack_parameters,
                        f"{output_path}/{net.name}_stage{stage}_mla.ifm.mlc"
                    )
                    N2ACompiledBackendRunner.pack_and_write_outputs_to_ofm_chk_file(
                        ofm_ref, model_graph.compile_properties.unpack_parameters,
                        f"{output_path}/{net.name}_stage{stage}_mla.ofm_chk.mlc"
                    )
                else:
                    _create_ofm_chk_mlc(
                        backend_runner, node, placeholder_values,
                        tessellated_placeholder_values, stage
                    )

            try:
                generate_l1_based_model(
                    model_graph.compile_properties, None, use_dram=True,
                    num_dma_controllers=4, l2_caching_mode=mla_config.l2_caching_mode
                )
            except RuntimeError as e:
                # If the main process crashed, stop the evaluate process before exiting.
                # The evaluate process only exists when mlc file verification is enabled.
                if mla_config.verify_mlc_files and backend_runner.evaluate_task is not None:
                    backend_runner.evaluate_task.terminate()
                raise e

            if mla_config.verify_mlc_files:
                if enable_l1_check:
                    pass
                elif backend_runner.evaluate_task_type == EvaluateTaskType.EXT_CMD:
                    if backend_runner.evaluate_task.poll() is None:
                        sima_log_dbg('Waiting for evaluate process to finish.')
                    stdout, stderr = backend_runner.evaluate_task.communicate()
                    for l in stdout.decode().splitlines():
                        sima_log_info(f" {l}")
                    for l in stderr.decode().splitlines():
                        sima_log_error(f" {l}")
                    if backend_runner.evaluate_task.returncode != 0:
                        raise RuntimeError("Error occurred when creating ofm_chk.mlc file")
                elif backend_runner.evaluate_task_type == EvaluateTaskType.FORK_PROCESS:
                    backend_runner.evaluate_task.join()
                backend_runner.execute_model_graph(stage)

            # Update stage number in BackendIR
            node.ir.stage = stage
            stage += 1

    net.topological_sort()
    update_awesomenet_status(net, Status.BACKEND_IR_COMPILED)
    return compiler_batch_size

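
# End-to-end sketch (hedged; `quantized_net` and `out_dir` are hypothetical, and the
# net is assumed to already be partitioned into MLA sub-graphs):
#
#     lowered = translate_sub_awesome_net_to_modelgraph(quantized_net)
#     if isinstance(lowered, AwesomeNet):
#         used_batch = compile_mla_code(lowered, out_dir, MLACompilerConfig(),
#                                       desired_batch_size=2)
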
#############################################
# n2a_compiler interfaces
#############################################

class AfeToN2AOperationConverter:
    """
    Abstract base class providing an interface for generation of CompilerVertex
    from an AwesomeOperation.
    """

    @staticmethod
    @abstractmethod
    def gen_vertex(output_type: DataValue[TensorType],
                   attrs: Optional[attributes.AwesomeAttributes],
                   quant_attrs: Optional[attributes.AwesomeQuantAttrBase],
                   input_dict: Dict[InputName, Any],
                   data_layout: str,
                   name: str,
                   rounding_type: RoundType = RoundType.UPWARD) -> CompilerVertex:
        raise NotImplementedError(
            "Abstract base class, use appropriate child class to generate CompilerVertex")

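
# A concrete converter follows this shape (illustrative sketch only;
# `gen_compiler_xyz_vertex` is a hypothetical n2a_compiler_utils helper):
#
#     class XyzN2ACompiler(AfeToN2AOperationConverter):
#         @staticmethod
#         def gen_vertex(output_type, attrs, quant_attrs, input_dict, data_layout,
#                        name, rounding_type=RoundType.UPWARD) -> CompilerVertex:
#             data_vertex = input_dict["data"]
#             # Select quantized or floating-point attributes, then build the vertex.
#             return n2a_compiler_utils.gen_compiler_xyz_vertex(name, data_vertex)
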
class PlaceholderN2ACompiler(AfeToN2AOperationConverter):

    @staticmethod
    def gen_vertex(output_type: DataValue[TensorType],
                   attrs: Optional[attributes.AwesomeAttributes],
                   quant_attrs: Optional[attributes.AwesomeQuantAttrBase],
                   input_dict: Dict[InputName, Any],
                   data_layout: str,
                   name: str,
                   rounding_type: RoundType = RoundType.UPWARD) -> CompilerVertex:
        # Get output shape
        output_type = get_expected_tensor_value(output_type)
        shape = output_type.shape

        # Convert shape to compiler supported shape
        shape = n2a_compiler_utils.shape_to_compiler_data_layout(
            shape, current_layout=data_layout)

        # Create fake data and convert shape to HWC
        dtype = n2a_compiler_utils.fix_dtype(output_type.scalar.numpy_type())
        fake_data = n2a_compiler_utils.get_fake_data(shape, dtype)

        # Generate vertex and return it
        return n2a_compiler_utils.gen_compiler_placeholder_vertex(name, fake_data)

class ConstantN2ACompiler(AfeToN2AOperationConverter):

    @staticmethod
    def gen_vertex(output_type: DataValue[TensorType],
                   attrs: Optional[attributes.AwesomeAttributes],
                   quant_attrs: Optional[attributes.AwesomeQuantAttrBase],
                   input_dict: Dict[InputName, Any],
                   data_layout: str,
                   name: str,
                   rounding_type: RoundType = RoundType.UPWARD) -> CompilerVertex:
        if quant_attrs is None:
            assert isinstance(attrs, attributes.ConstantAttrs)
            data = attrs.data
        else:
            assert isinstance(quant_attrs, attributes.ConstantQuantAttrs)
            data = quant_attrs.quant_data

        # Leave out the batch dimension ('NHWC' layout is assumed), if constant is
        # not a 1D vector.
        if data.ndim == 5:
            data = data[0]
        elif data.ndim > 1:
            assert data.ndim == 4 and data.shape[0] == 1
            data = data[0]

        # Generate vertex and return it
        return n2a_compiler_utils.gen_compiler_constant_vertex(name, data)

class TupleN2A(AfeToN2AOperationConverter):

    @staticmethod
    def gen_vertex(output_type: DataValue[TensorType],
                   attrs: Optional[attributes.AwesomeAttributes],
                   quant_attrs: Optional[attributes.AwesomeQuantAttrBase],
                   input_dict: Dict[InputName, Any],
                   data_layout: str,
                   name: str,
                   rounding_type: RoundType = RoundType.UPWARD) -> CompilerVertex:
        vertices = list(input_dict.values())
        return n2a_compiler_utils.gen_compiler_tuple_vertex(name, vertices)

class ConcatenateN2A(AfeToN2AOperationConverter):

    @staticmethod
    def gen_vertex(output_type: DataValue[TensorType],
                   attrs: Optional[attributes.AwesomeAttributes],
                   quant_attrs: Optional[attributes.AwesomeQuantAttrBase],
                   input_dict: Dict[InputName, Any],
                   data_layout: str,
                   name: str,
                   rounding_type: RoundType = RoundType.UPWARD) -> CompilerVertex:
        # Get input vertices
        if len(input_dict) == 1:
            input_vertices = input_dict['data'].inputs
        else:
            input_vertices = list(input_dict.values())

        # Get Concat attributes
        is_quantized = quant_attrs is not None
        if is_quantized:
            assert isinstance(quant_attrs, attributes.ConcatQuantAttrs)
            axis = quant_attrs.attrs.axis
            requants = quant_attrs.requants

            # Verify that the scale correction is representable in the compiler.
            # It is limited to int8 when the input type is int8.
            assert all(t.scalar == quant_attrs.attrs.input_types[0].scalar
                       for t in quant_attrs.attrs.input_types)
            if quant_attrs.attrs.input_types[0].scalar == ScalarType.int8:
                sc_correction_limit = np.iinfo(np.int8).max
            else:
                sc_correction_limit = np.iinfo(np.int32).max
            for requant in requants:
                assert isinstance(requant, FractionalZeroRequantization)
                assert 0 <= requant.sc_correction <= sc_correction_limit, \
                    "Scale correction is not representable"
        else:
            assert isinstance(attrs, attributes.ConcatenateAttrs)
            axis = attrs.axis
            output_dtype = input_vertices[0].operator.shape.dtype.np_dtype
            requants = [get_id_requantization(output_dtype)] * len(input_vertices)

        axis = n2a_compiler_utils.axis_to_compiler_data_layout(axis, data_layout)
        return n2a_compiler_utils.gen_compiler_concat_vertex(
            name, input_vertices, requants, axis)

class PreluN2A(AfeToN2AOperationConverter):

    @staticmethod
    def gen_vertex(output_type: DataValue[TensorType],
                   attrs: Optional[attributes.AwesomeAttributes],
                   quant_attrs: Optional[attributes.AwesomeQuantAttrBase],
                   input_dict: Dict[InputName, Any],
                   data_layout: str,
                   name: str,
                   rounding_type: RoundType = RoundType.UPWARD) -> CompilerVertex:
        datav, = input_dict.values()
        is_quantized = quant_attrs is not None
        if is_quantized:
            assert isinstance(quant_attrs, attributes.PReluQuantAttrs)
            zp = quant_attrs.data_zero_point
            shift = quant_attrs.alpha_shift
            alpha = quant_attrs.quant_alpha
            override_rounding_type = rounding_type
        else:
            assert isinstance(attrs, attributes.PReluAttrs)
            zp = shift = 0
            alpha = attrs.alpha
            override_rounding_type = RoundType.TOEVEN
        alphav = n2a_compiler_utils.gen_compiler_prelu_alpha_vertex(f"{name}_alpha", alpha)
        return n2a_compiler_utils.gen_compiler_prelu_vertex(
            name, datav, alphav, zp, shift, override_rounding_type
        )

class TupleConcatenateN2A(AfeToN2AOperationConverter):

    @staticmethod
    def gen_vertex(output_type: DataValue[TensorType],
                   attrs: Optional[attributes.AwesomeAttributes],
                   quant_attrs: Optional[attributes.AwesomeQuantAttrBase],
                   input_dict: Dict[InputName, Any],
                   data_layout: str,
                   name: str,
                   rounding_type: RoundType = RoundType.UPWARD) -> CompilerVertex:
        # Modify the input dictionary so that input_dict['data'].inputs is a list
        # of input vertices for ConcatenateN2A
        class _DummyVertex:
            inputs = [v for v in input_dict.values()]

        if isinstance(quant_attrs, attributes.ConcatQuantAttrs):
            return ConcatenateN2A.gen_vertex(
                output_type, quant_attrs.attrs, quant_attrs, {"data": _DummyVertex()},
                data_layout, name, rounding_type=rounding_type)
        elif isinstance(attrs, attributes.ConcatenateAttrs):
            return ConcatenateN2A.gen_vertex(
                output_type, attrs, None, {"data": _DummyVertex()}, data_layout, name,
                rounding_type=rounding_type)
        else:
            assert isinstance(attrs, attributes.TupleConcatenateAttrs)
            return ConcatenateN2A.gen_vertex(
                output_type, attrs.concat_attrs, None, {"data": _DummyVertex()},
                data_layout, name, rounding_type=rounding_type)

class MultiplyN2ACompiler(AfeToN2AOperationConverter):

    @staticmethod
    def gen_vertex(output_type: DataValue[TensorType],
                   attrs: Optional[attributes.AwesomeAttributes],
                   quant_attrs: Optional[attributes.AwesomeQuantAttrBase],
                   input_dict: Dict[InputName, Any],
                   data_layout: str,
                   name: str,
                   rounding_type: RoundType = RoundType.UPWARD) -> CompilerVertex:
        # Get input vertices
        lhs_vertex = input_dict["lhs"]
        rhs_vertex = input_dict["rhs"]

        is_quantized = quant_attrs is not None
        if is_quantized:
            assert isinstance(quant_attrs, attributes.MultiplyQuantAttrs)
            zp_lhs = quant_attrs.lhs_zero_point
            zp_rhs = quant_attrs.rhs_zero_point
            requant = quant_attrs.requant
            intrinsic_shift = quant_attrs.intrinsic_shift
        else:
            zp_lhs = zp_rhs = 0
            requant = get_id_requantization(bfloat16)
            intrinsic_shift = 0
        return n2a_compiler_utils.gen_compiler_mul_vertex(
            name, lhs_vertex, rhs_vertex, zp_lhs, zp_rhs, requant, intrinsic_shift)

class MaxPoolN2ACompiler(AfeToN2AOperationConverter):

    @staticmethod
    def gen_vertex(output_type: DataValue[TensorType],
                   attrs: Optional[attributes.AwesomeAttributes],
                   quant_attrs: Optional[attributes.AwesomeQuantAttrBase],
                   input_dict: Dict[InputName, Any],
                   data_layout: str,
                   name: str,
                   rounding_type: RoundType = RoundType.UPWARD) -> CompilerVertex:
        # Get input vertices
        data_vertex = input_dict['data']

        # Get MaxPool attributes
        is_quantized = quant_attrs is not None
        if is_quantized:
            assert isinstance(quant_attrs, attributes.PoolQuantAttrs)
            pool_attrs = quant_attrs.pool_attrs
            requant = quant_attrs.requant
        else:
            assert isinstance(attrs, attributes.MaxPoolAttrs)
            pool_attrs = attrs
            requant = None

        pool_size = tuple(pool_attrs.pool_size[1:-1])
        strides = tuple(pool_attrs.strides[1:-1])
        padding = sum(pool_attrs.padding[1:-1], start=tuple())
        if len(pool_attrs.pool_size) == 4:
            pool_size = (1, *pool_size)
            strides = (1, *strides)
            padding = (0, 0, *padding)

        output_type = get_expected_tensor_value(output_type)
        output_dtype = n2a_compiler_utils.fix_dtype(output_type.scalar.numpy_type())
        output_shape = n2a_compiler_utils.shape_to_compiler_data_layout(
            output_type.shape, current_layout=data_layout
        )

        # Generate vertex and return it
        return n2a_compiler_utils.gen_compiler_maxpool_vertex(
            name, output_shape, data_vertex, pool_size, strides, padding, output_dtype,
            requant=requant)

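
# Note on the pooling parameter massaging above: `pool_attrs.padding[1:-1]` is a
# sequence of per-axis (before, after) pairs, and summing with an empty-tuple start
# flattens it into one tuple:
#
#     >>> sum(((1, 1), (2, 2)), start=tuple())
#     (1, 1, 2, 2)
#
# For 4D inputs a unit depth axis is then prepended, so the compiler always receives
# (D, H, W) pool sizes, strides, and paddings.
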
class AvgPoolN2ACompiler(AfeToN2AOperationConverter):

    @staticmethod
    def gen_vertex(output_type: DataValue[TensorType],
                   attrs: Optional[attributes.AwesomeAttributes],
                   quant_attrs: Optional[attributes.AwesomeQuantAttrBase],
                   input_dict: Dict[InputName, Any],
                   data_layout: str,
                   name: str,
                   rounding_type: RoundType = RoundType.UPWARD) -> CompilerVertex:
        # Get input vertices
        data_vertex = input_dict['data']

        # Get AvgPool attributes
        is_quantized = quant_attrs is not None
        if is_quantized:
            assert isinstance(quant_attrs, attributes.PoolQuantAttrs)
            pool_attrs = quant_attrs.pool_attrs
            override_rounding_type = quant_attrs.rounding_type
            requant = quant_attrs.requant
            pad_value = quant_attrs.pad_value
        else:
            assert isinstance(attrs, attributes.AvgPoolAttrs)
            pool_attrs = attrs
            override_rounding_type = RoundType.TOEVEN
            requant = None
            pad_value = 0

        pool_size = tuple(pool_attrs.pool_size[1:-1])
        strides = tuple(pool_attrs.strides[1:-1])
        padding = sum(pool_attrs.padding[1:-1], start=tuple())
        if len(pool_attrs.pool_size) == 4:
            pool_size = (1, *pool_size)
            strides = (1, *strides)
            padding = (0, 0, *padding)

        output_type = get_expected_tensor_value(output_type)
        output_dtype = n2a_compiler_utils.fix_dtype(output_type.scalar.numpy_type())
        output_shape = n2a_compiler_utils.shape_to_compiler_data_layout(
            output_type.shape, current_layout=data_layout
        )

        # Generate vertex and return it
        input_shape = data_vertex.operator.shape.shape
        op = "global" if input_shape[:-1] == pool_size else "average"
        return n2a_compiler_utils.gen_compiler_avgpool_vertex(
            name, output_shape, data_vertex, op=op, pool_size=pool_size,
            strides=strides, padding=padding, rounding_type=override_rounding_type,
            output_dtype=output_dtype, requant=requant, pad_value=pad_value
        )

class VarianceN2ACompiler(AfeToN2AOperationConverter):

    @staticmethod
    def gen_vertex(output_type: DataValue[TensorType],
                   attrs: Optional[attributes.AwesomeAttributes],
                   quant_attrs: Optional[attributes.AwesomeQuantAttrBase],
                   input_dict: Dict[InputName, Any],
                   data_layout: str,
                   name: str,
                   rounding_type: RoundType = RoundType.UPWARD) -> CompilerVertex:
        # Get input vertices
        data_vertex = input_dict[InputName('data')]
        mean_vertex = input_dict[InputName('mean')]

        is_quantized = quant_attrs is not None
        if is_quantized:
            assert isinstance(quant_attrs, attributes.VarianceQuantAttrs)
            requant = quant_attrs.requant
            requant_var = quant_attrs.requant_var
        else:
            assert isinstance(attrs, attributes.VarianceAttrs)
            divisor = np.float32(1.0 / prod(attrs.input_data_shape[1:-1]))
            requant = Renormalization(
                divisor, utils.create_and_verify_narrowing(0, RoundType.TOEVEN, bfloat16)
            )
            requant_var = None
        return n2a_compiler_utils.gen_compiler_variance_vertex(
            name, data_vertex, mean_vertex, requant, requant_var
        )

class MeanN2ACompiler(AfeToN2AOperationConverter):

    @staticmethod
    def gen_vertex(output_type: DataValue[TensorType],
                   attrs: Optional[attributes.AwesomeAttributes],
                   quant_attrs: Optional[attributes.AwesomeQuantAttrBase],
                   input_dict: Dict[InputName, Any],
                   data_layout: str,
                   name: str,
                   rounding_type: RoundType = RoundType.UPWARD) -> CompilerVertex:
        """
        Limitations:
          * input shape is in the NDHWC format, where N=1.
          * mean is only supported along D, H and/or W axes.
          * mean does not reduce tensor dimensions, output is NDHWC.

        Implementation of mean relies on avgpool.
        """
        # Get input vertices
        data_vertex = input_dict['data']
        input_shape = data_vertex.operator.output_shape

        if quant_attrs is not None:
            assert isinstance(quant_attrs, attributes.MeanQuantAttrs)
            attrs = quant_attrs.attrs

        assert isinstance(attrs, attributes.MeanAttrs)
        assert attrs.keepdims, "MLA does not support MeanOp with keepdims=False."
        assert not attrs.exclude, "Axes are not simplified by SimplifyAxisExcludeAttr."
        assert 0 not in attrs.axis, "MLA does not support MeanOp on batch dimension."
        assert len(attrs.shape) - 1 not in attrs.axis, (
            "MLA does not support MeanOp on channel dimension."
        )

        # Convert axes to the pool size.
        pool_size = tuple(
            x if i in attrs.axis else 1 for i, x in enumerate(attrs.shape[1:-1], start=1)
        )
        if len(attrs.shape) == 4:
            pool_size = (1, *pool_size)
        pool_op = "global" if pool_size == input_shape[:-1] else "average"
        padding = (0, ) * len(pool_size) * 2
        output_shape = n2a_compiler_utils.shape_to_compiler_data_layout(
            get_expected_tensor_value(output_type).shape, current_layout=data_layout
        )

        # Generate vertex and return it.
        # Hardcode the rounding_type because it is hardcoded in
        # operation_functions.py::mean.
        return n2a_compiler_utils.gen_compiler_avgpool_vertex(
            name, output_shape, data_vertex, op=pool_op, pool_size=pool_size,
            strides=pool_size, padding=padding, rounding_type=RoundType.TRUNC
        )

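
# Example of the axis-to-pool-size conversion above, for an NHWC input of shape
# (1, 7, 7, 512) with mean over axes (1, 2):
#
#     shape, axis = (1, 7, 7, 512), (1, 2)
#     pool_size = tuple(x if i in axis else 1
#                       for i, x in enumerate(shape[1:-1], start=1))   # (7, 7)
#
# After prepending the unit depth axis this becomes (1, 7, 7); since it covers the
# full spatial extent, a "global" average pool is selected.
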
class ReluN2ACompiler(AfeToN2AOperationConverter):

    @staticmethod
    def gen_vertex(output_type: DataValue[TensorType],
                   attrs: Optional[attributes.AwesomeAttributes],
                   quant_attrs: Optional[attributes.AwesomeQuantAttrBase],
                   input_dict: Dict[InputName, Any],
                   data_layout: str,
                   name: str,
                   rounding_type: RoundType = RoundType.UPWARD) -> CompilerVertex:
        # Get input vertices
        data_vertex = input_dict["data"]

        # Get zero point
        is_quantized = quant_attrs is not None
        if is_quantized:
            assert isinstance(quant_attrs, attributes.ReluQuantAttrs)
            node_zp = quant_attrs.zero_point
        else:
            node_zp = 0
        return n2a_compiler_utils.gen_compiler_relu_vertex(name, data_vertex, node_zp)

class ClipN2ACompiler(AfeToN2AOperationConverter):

    @staticmethod
    def gen_vertex(output_type: DataValue[TensorType],
                   attrs: Optional[attributes.AwesomeAttributes],
                   quant_attrs: Optional[attributes.AwesomeQuantAttrBase],
                   input_dict: Dict[InputName, Any],
                   data_layout: str,
                   name: str,
                   rounding_type: RoundType = RoundType.UPWARD) -> CompilerVertex:
        # Get input vertices
        data_vertex = input_dict["data"]

        # Get attributes
        _attrs = attrs if attrs is not None else quant_attrs
        assert isinstance(_attrs, (attributes.ClipAttrs, attributes.ClipQuantAttrs))
        return n2a_compiler_utils.gen_compiler_clip_vertex(
            name, data_vertex, _attrs.a_min, _attrs.a_max)

class LeakyReluN2ACompiler(AfeToN2AOperationConverter):

    @staticmethod
    def gen_vertex(output_type: DataValue[TensorType],
                   attrs: Optional[attributes.AwesomeAttributes],
                   quant_attrs: Optional[attributes.AwesomeQuantAttrBase],
                   input_dict: Dict[InputName, Any],
                   data_layout: str,
                   name: str,
                   rounding_type: RoundType = RoundType.UPWARD) -> CompilerVertex:
        # Get input vertices
        data_vertex = input_dict["data"]

        is_quantized = quant_attrs is not None
        if is_quantized:
            assert isinstance(quant_attrs, attributes.LeakyReluCompositeQuantAttrs)
            if quant_attrs.leaky_relu_uses_udf:
                # Use UDF version.
                assert quant_attrs.udf_quant_attrs is not None
                lut_array = quant_attrs.udf_quant_attrs.lookup_table
                lut_vertex = n2a_compiler_utils.gen_compiler_constant_vertex(
                    name + '_lut', lut_array)
                return n2a_compiler_utils.gen_compiler_udf_vertex(
                    name, data_vertex, lut_vertex)
            else:
                # Use breakdown version.
                assert quant_attrs.leaky_relu_quant_attrs is not None
                node_zp = quant_attrs.leaky_relu_quant_attrs.zero_point
                alpha = quant_attrs.leaky_relu_quant_attrs.alpha
                right_shift = quant_attrs.leaky_relu_quant_attrs.right_shift
                return n2a_compiler_utils.gen_compiler_leaky_relu_vertex(
                    name, data_vertex, alpha, node_zp, right_shift, rounding_type)
        else:
            assert isinstance(attrs, attributes.LeakyReluAttrs)
            alpha = attrs.alpha
            node_zp = right_shift = 0
            return n2a_compiler_utils.gen_compiler_leaky_relu_vertex(
                name, data_vertex, alpha, node_zp, right_shift, RoundType.TOEVEN)

class LRNN2ACompiler(AfeToN2AOperationConverter):

    @staticmethod
    def gen_vertex(output_type: DataValue[TensorType],
                   attrs: Optional[attributes.AwesomeAttributes],
                   quant_attrs: Optional[attributes.AwesomeQuantAttrBase],
                   input_dict: Dict[InputName, Any],
                   data_layout: str,
                   name: str,
                   rounding_type: RoundType = RoundType.UPWARD) -> CompilerVertex:
        assert isinstance(quant_attrs, attributes.LRNQuantAttrs)
        assert quant_attrs.axis == data_layout.index("C")

        # Get input vertices
        data_vertex = input_dict["data"]
        lut_vertex = n2a_compiler_utils.gen_compiler_constant_vertex(
            name + '_lut', quant_attrs.lookup_table.reshape(16, 16))

        # Hardcode the rounding_type because it is hardcoded in
        # operation_functions.py::lrn.
        return n2a_compiler_utils.gen_compiler_lrn_vertex(
            name, data_vertex, lut_vertex, quant_attrs.size, quant_attrs.input_zp,
            quant_attrs.lut_scale, quant_attrs.lut_zp_corr, quant_attrs.lut_sh,
            quant_attrs.output_scale, quant_attrs.output_zp_corr, quant_attrs.output_sh,
            RoundType.UPWARD)

class UpsamplingN2ACompiler(AfeToN2AOperationConverter):

    @staticmethod
    def gen_vertex(output_type: DataValue[TensorType],
                   attrs: Optional[attributes.AwesomeAttributes],
                   quant_attrs: Optional[attributes.AwesomeQuantAttrBase],
                   input_dict: Dict[InputName, Any],
                   data_layout: str,
                   name: str,
                   rounding_type: RoundType = RoundType.UPWARD) -> CompilerVertex:
        # Get input vertices
        data_vertex = input_dict["data"]

        is_quantized = quant_attrs is not None
        if is_quantized:
            assert isinstance(quant_attrs, attributes.UpsamplingQuantAttrs)
            _attrs = quant_attrs.upsampling_attrs
            zp = quant_attrs.input_zp
            override_rounding_type = quant_attrs.rounding_type
        else:
            assert isinstance(attrs, attributes.UpsamplingAttrs)
            _attrs = attrs
            zp = 0
            override_rounding_type = RoundType.TOEVEN

        input_shape: tuple[int, ...] = data_vertex.operator.shape.shape
        target_spatial_shape: tuple[int, ...] = (
            input_shape[0], input_shape[1] * _attrs.scale_h, input_shape[2] * _attrs.scale_w
        )
        return n2a_compiler_utils.generate_resize_vertex(
            data_vertex, _attrs.method, target_spatial_shape, zp, override_rounding_type,
            mode="align_corners" if _attrs.align_corners else "half_pixel",
            name=name,
        )

class ImageResize2DN2ACompiler(AfeToN2AOperationConverter):

    @staticmethod
    def gen_vertex(output_type: DataValue[TensorType],
                   attrs: Optional[attributes.AwesomeAttributes],
                   quant_attrs: Optional[attributes.AwesomeQuantAttrBase],
                   input_dict: Dict[InputName, Any],
                   data_layout: str,
                   name: str,
                   rounding_type: RoundType = RoundType.UPWARD) -> CompilerVertex:
        # Get input vertices
        data_vertex = input_dict["data"]

        is_quantized = quant_attrs is not None
        if is_quantized:
            assert isinstance(quant_attrs, attributes.ImageResize2DQuantAttrs)
            image_resize2d_attrs = quant_attrs.image_resize2d_attrs
            zp = quant_attrs.input_zp
            override_rounding_type = quant_attrs.rounding_type
        else:
            image_resize2d_attrs = attrs
            zp = 0
            override_rounding_type = RoundType.TOEVEN

        assert isinstance(image_resize2d_attrs, attributes.ImageResize2DAttrs)
        target_spatial_shape: tuple[int, ...] = (
            data_vertex.operator.output_shape[0],
            image_resize2d_attrs.size[0],
            image_resize2d_attrs.size[1]
        )
        mode: str = image_resize2d_attrs.coordinate_transformation_mode
        tf_ver = 2 if mode in ['half_pixel', 'pytorch_half_pixel', 'align_corners',
                               'asymmetric'] else 1
        return n2a_compiler_utils.generate_resize_vertex(
            data_vertex, image_resize2d_attrs.method, target_spatial_shape, zp,
            override_rounding_type, tf_ver, mode, name
        )

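
# The `tf_ver` flag above selects between the two resize semantics supported by the
# vertex builder: the 'half_pixel', 'pytorch_half_pixel', 'align_corners', and
# 'asymmetric' coordinate transformation modes map to tf_ver=2, while any other mode
# falls back to tf_ver=1.
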
################################
# Convolution, Add, Activations
################################

class ConvN2ACompiler(AfeToN2AOperationConverter):
    """
    Converter for all supported variants of composite 2D convolution,
    both regular and transposed.
    """

    @staticmethod
    def _gen_vertex(
            output_type: DataValue[TensorType],
            attrs: attributes.ConvQuantAttrs | attributes.ConvAddActivationAttrs,
            name: str,
            data_vertex: CompilerVertex,
            weight_data: np.ndarray,
            bias_data: Optional[np.ndarray] = None) -> CompilerVertex:
        # Get Convolution attributes
        conv_attrs: attributes.ConvAttrs = attrs.conv_attrs
        match conv_attrs.num_spatial_dimensions:
            case 2:
                data_layout = "NHWC"
                stride = (1, *conv_attrs.stride)
                dilation = (1, *conv_attrs.dilation)
                padding = sum(conv_attrs.padding, start=(0, 0))
            case 3:
                data_layout = "NDHWC"
                stride = conv_attrs.stride
                dilation = conv_attrs.dilation
                padding = sum(conv_attrs.padding, start=tuple())
            case _ as unreachable:
                assert_never(unreachable)

        output_type = get_expected_tensor_value(output_type)
        output_dtype = n2a_compiler_utils.fix_dtype(output_type.scalar.numpy_type())
        output_shape = n2a_compiler_utils.shape_to_compiler_data_layout(
            output_type.shape, current_layout=data_layout
        )

        # Check if it's a transposed convolution
        is_transposed = conv_attrs.is_transposed
        # Check if it is a depthwise convolution with channel multiplier == 1
        is_depthwise = conv_attrs.is_depthwise_one_channel

        if is_depthwise and is_transposed:
            # Depthwise transposed convolution will be implemented using regular
            # convolution. Flip the weight tensor to make the equivalent weights
            # for regular convolution.
            weight_data = np.flip(weight_data,
                                  axis=tuple(range(conv_attrs.num_spatial_dimensions)))

        weight_data = n2a_compiler_utils.to_conv2d_weights_layout(weight_data)

        # Create weights and bias constant vertices
        weights_vertex = n2a_compiler_utils.gen_compiler_weight_vertex(name, weight_data)
        if bias_data is not None:
            bias_data = n2a_compiler_utils.cast_bias(bias_data)
            bias_vertex = n2a_compiler_utils.gen_compiler_bias_vertex(name, bias_data)
        else:
            bias_vertex = None

        # Get the input zp and requant param
        is_quantized = isinstance(attrs, attributes.ConvQuantAttrs)
        # Whether bfloat16 with int weights quant scheme is used
        is_bfloat16_with_int_weights = (
            weight_data.dtype in QuantizedWeightDtypes
            and output_type.scalar == ScalarType.bfloat16
        )
        if is_quantized and not is_bfloat16_with_int_weights:
            assert isinstance(attrs.requant,
                              (ArithFoldedRequantization, TFLiteRequantization))
            requant = afe.ir.quantization_utils.fix_requantization(attrs.requant)
            input_zp = attrs.input_zp
            output_zp = attrs.zero_point
            msb_left_shift = attrs.msb_left_shift
        else:
            if is_bfloat16_with_int_weights:
                assert isinstance(attrs, attributes.ConvQuantAttrs)
                requant = attrs.requant
            else:
                requant = get_id_requantization(output_dtype)
            input_zp = output_zp = 0
            msb_left_shift = False

        clip_range = None
        match attrs.activ_attrs:
            case attributes.ReluAttrs() | attributes.ReluQuantAttrs():
                # Relu activation will be fused with convolution.
                fused_activation = Activation.RELU
            case attributes.ClipAttrs() | attributes.ClipQuantAttrs():
                # Clip activation will be fused with convolution.
                fused_activation = Activation.CLIP
                clip_range = attrs.activ_attrs.a_min, attrs.activ_attrs.a_max
            case None:
                fused_activation = Activation.NONE
            case _:
                raise ValueError("Unhandled activation type")

        gen_vertex_func: n2a_compiler_utils.ConvVertexProtocol
        if is_transposed:
            gen_vertex_func = n2a_compiler_utils.gen_compiler_conv2d_transpose_vertex
        else:
            gen_vertex_func = n2a_compiler_utils.gen_compiler_conv2d_vertex

        conv_vertex = gen_vertex_func(
            name, output_shape, data_vertex, weights_vertex, bias_vertex, input_zp,
            output_zp, stride, padding, dilation, msb_left_shift=msb_left_shift,
            requant=requant, activ=fused_activation, is_depthwise=is_depthwise,
            groups=conv_attrs.groups, clip_range=clip_range
        )
        return conv_vertex

    @staticmethod
    def _check_composite_operator_support(
            attrs: attributes.ConvAddActivationAttrs | attributes.ConvQuantAttrs):
        """
        Checks if the composite convolution operator is supported by the N2A compiler.
        Currently, supported composite convolution operators are:
            Conv + Add
            Conv + Add + Relu
            ConvTransposed + Add
            ConvTransposed + Add + Relu

        :param attrs: Pre-quantization attributes or quantization attributes for the
            given operation
        :return: None. Raises an exception if the composite operation is not supported
        """
        _operator_is_supported = isinstance(
            attrs.activ_attrs,
            (attributes.ReluAttrs, attributes.ReluQuantAttrs,
             attributes.ClipAttrs, attributes.ClipQuantAttrs, NoneType))
        if not _operator_is_supported:
            raise NotImplementedError(
                f"N2A Compiler does not support composite convolution operation with "
                f"{type(attrs.activ_attrs)} activation attributes.")

    @staticmethod
    def gen_vertex(output_type: DataValue[TensorType],
                   attrs: Optional[attributes.AwesomeAttributes],
                   quant_attrs: Optional[attributes.AwesomeQuantAttrBase],
                   input_dict: Dict[InputName, Any],
                   data_layout: str,
                   name: str,
                   rounding_type: RoundType = RoundType.UPWARD) -> CompilerVertex:
        # Obtain weight and bias data
        is_quantized = quant_attrs is not None
        if is_quantized:
            assert isinstance(quant_attrs, attributes.ConvQuantAttrs)
            weight_data = quant_attrs.weight_quant_data
            bias_data = quant_attrs.bias_quant_data
            cur_attrs = quant_attrs
        else:
            assert isinstance(attrs, attributes.ConvAddActivationAttrs)
            weight_data = attrs.weights_attrs.data
            bias_data = attrs.bias_attrs.data if attrs.bias_attrs else None
            cur_attrs = attrs

        # Check if composite operator is supported
        ConvN2ACompiler._check_composite_operator_support(cur_attrs)

        # Get input vertices
        data_vertex = input_dict['data']

        return ConvN2ACompiler._gen_vertex(
            output_type, cur_attrs, name, data_vertex, weight_data, bias_data)

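
# Illustration of the weight flip used for depthwise transposed convolution above:
# a stride-1 transposed convolution is equivalent to a regular convolution with the
# spatial kernel taps reversed, which is what np.flip over the spatial axes produces.
#
#     k = np.arange(9, dtype=np.float32).reshape(3, 3, 1, 1)   # 2D kernel, HW leading
#     k_flipped = np.flip(k, axis=(0, 1))                      # reverses H and W taps
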
################################
# Add, Activations
################################

class AddActivationN2ACompiler(AfeToN2AOperationConverter):

    @staticmethod
    def gen_vertex(output_type: DataValue[TensorType],
                   attrs: Optional[attributes.AwesomeAttributes],
                   quant_attrs: Optional[attributes.AwesomeQuantAttrBase],
                   input_dict: Dict[InputName, Any],
                   data_layout: str,
                   name: str,
                   rounding_type: RoundType = RoundType.UPWARD) -> CompilerVertex:
        is_quantized = quant_attrs is not None
        if is_quantized:
            assert isinstance(quant_attrs, attributes.AddQuantAttrs)
            activ_attrs = quant_attrs.activ_attrs
            lhs_scale = quant_attrs.lhs_scale
            rhs_scale = quant_attrs.rhs_scale
            requant = quant_attrs.requant
        else:
            assert isinstance(attrs, attributes.AddActivationAttrs)
            activ_attrs = attrs.activ_attrs
            lhs_scale = rhs_scale = 1.0
            requant = get_id_requantization(bfloat16)

        clip_range = None
        if isinstance(activ_attrs, attributes.ReluAttrs | attributes.ReluQuantAttrs):
            activation = Activation.RELU
        elif isinstance(activ_attrs, attributes.ClipAttrs | attributes.ClipQuantAttrs):
            activation = Activation.CLIP
            clip_range = activ_attrs.a_min, activ_attrs.a_max
        else:
            activation = Activation.NONE

        # Get input vertices
        lhs_vertex = input_dict["lhs"]
        rhs_vertex = input_dict["rhs"]

        add_vertex = n2a_compiler_utils.gen_compiler_add_subtract_vertex(
            name, lhs_vertex, rhs_vertex, lhs_scale, rhs_scale, requant, "add",
            activ=activation, clip_range=clip_range)
        return add_vertex

class ConstantMultiplyAddN2ACompiler(AfeToN2AOperationConverter):

    @staticmethod
    def gen_vertex(output_type: DataValue[TensorType],
                   attrs: Optional[attributes.AwesomeAttributes],
                   quant_attrs: Optional[attributes.AwesomeQuantAttrBase],
                   input_dict: Dict[InputName, Any],
                   data_layout: str,
                   name: str,
                   rounding_type: RoundType = RoundType.UPWARD) -> CompilerVertex:
        if quant_attrs is not None:
            return AddActivationN2ACompiler.gen_vertex(
                output_type, attrs, quant_attrs, input_dict, data_layout, name,
                rounding_type
            )

        # For bfloat16, decompose constant_multiply_add into multiply and add.
        assert isinstance(attrs, attributes.ConstantMultiplyAddAttrs)

        # Get input vertices
        lhs_vertex = input_dict["lhs"]
        rhs_vertex = input_dict["rhs"]

        lhs_const_vertex = n2a_compiler_utils.gen_compiler_constant_vertex(
            f"{name}/lhs_const", attrs.in1_const_attrs.data
        )
        lhs_multiply_attrs = attributes.MultiplyAttrs(
            attrs.scalar_type, attrs.lhs_input_shape, attrs.in1_const_attrs.data.shape
        )
        lhs_multiply_input_dict = {"lhs": lhs_vertex, "rhs": lhs_const_vertex}
        lhs_vertex = MultiplyN2ACompiler.gen_vertex(
            output_type, lhs_multiply_attrs, None, lhs_multiply_input_dict, data_layout,
            f"{name}/multiply_lhs", rounding_type
        )
        if attrs.in2_const_attrs:
            rhs_const_vertex = n2a_compiler_utils.gen_compiler_constant_vertex(
                f"{name}/rhs_const", attrs.in2_const_attrs.data
            )
            rhs_multiply_attrs = attributes.MultiplyAttrs(
                attrs.scalar_type, attrs.lhs_input_shape, attrs.in2_const_attrs.data.shape
            )
            rhs_multiply_input_dict = {"lhs": rhs_vertex, "rhs": rhs_const_vertex}
            rhs_vertex = MultiplyN2ACompiler.gen_vertex(
                output_type, rhs_multiply_attrs, None, rhs_multiply_input_dict,
                data_layout, f"{name}/multiply_rhs", rounding_type
            )
        lhs_scale = rhs_scale = 1.0
        requant = get_id_requantization(bfloat16)
        add_vertex = n2a_compiler_utils.gen_compiler_add_subtract_vertex(
            name, lhs_vertex, rhs_vertex, lhs_scale, rhs_scale, requant, "add"
        )
        return add_vertex

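
# The bfloat16 path above decomposes
#     out = lhs * c1 + rhs * c2      (the rhs * c2 term only when in2_const_attrs is set)
# into two Multiply vertices followed by a single Add vertex, instead of emitting one
# fused constant-multiply-add operation.
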
class SubtractN2ACompiler(AfeToN2AOperationConverter):

    @staticmethod
    def gen_vertex(output_type: DataValue[TensorType],
                   attrs: Optional[attributes.AwesomeAttributes],
                   quant_attrs: Optional[attributes.AwesomeQuantAttrBase],
                   input_dict: Dict[InputName, Any],
                   data_layout: str,
                   name: str,
                   rounding_type: RoundType = RoundType.UPWARD) -> CompilerVertex:
        # Get input vertices
        lhs_vertex = input_dict["lhs"]
        rhs_vertex = input_dict["rhs"]

        is_quantized = quant_attrs is not None
        if is_quantized:
            assert isinstance(quant_attrs, attributes.SubtractQuantAttrs)
            # Get Subtract attributes
            in1_scale = quant_attrs.lhs_scale
            in2_scale = quant_attrs.rhs_scale
            requant = quant_attrs.requant
        else:
            in1_scale = in2_scale = 1
            requant = get_id_requantization(bfloat16)
        return n2a_compiler_utils.gen_compiler_add_subtract_vertex(
            name, lhs_vertex, rhs_vertex, in1_scale, in2_scale, requant, op='sub')

class ArgMaxN2ACompiler(AfeToN2AOperationConverter):

    @staticmethod
    def gen_vertex(output_type: DataValue[TensorType],
                   attrs: Optional[attributes.AwesomeAttributes],
                   quant_attrs: Optional[attributes.AwesomeQuantAttrBase],
                   input_dict: Dict[InputName, Any],
                   data_layout: str,
                   name: str,
                   rounding_type: RoundType = RoundType.UPWARD) -> CompilerVertex:
        input_vertex = input_dict["data"]
        if quant_attrs is None:
            select_last_index = attrs.select_last_index
        else:
            select_last_index = quant_attrs.attrs.select_last_index
        return n2a_compiler_utils.gen_compiler_arg_min_max_vertex(
            name, input_vertex, True, select_last_index
        )

class UdfN2ACompiler(AfeToN2AOperationConverter):

    @staticmethod
    def _gen_vertex(quant_attrs: attributes.UDFQuantAttrs,
                    input_dict: Dict[InputName, Any], name: str) -> CompilerVertex:
        data_vertex = input_dict['data']
        lut_vertex = n2a_compiler_utils.gen_compiler_constant_vertex(
            name + '_lut', quant_attrs.lookup_table)
        return n2a_compiler_utils.gen_compiler_udf_vertex(name, data_vertex, lut_vertex)

    @staticmethod
    def gen_vertex(output_type: DataValue[TensorType],
                   attrs: Optional[attributes.AwesomeAttributes],
                   quant_attrs: Optional[attributes.AwesomeQuantAttrBase],
                   input_dict: Dict[InputName, Any],
                   data_layout: str,
                   name: str,
                   rounding_type: RoundType = RoundType.UPWARD) -> CompilerVertex:
        if quant_attrs is None and output_type.value.scalar == ScalarType.bfloat16:
            raise UserFacingException(
                f"Layer {name} does not have bfloat16 support. Consider using int8 or "
                f"int16 quantization mode."
            )
        assert isinstance(quant_attrs, attributes.UDFQuantAttrs)
        return UdfN2ACompiler._gen_vertex(quant_attrs, input_dict, name)

class ErfN2ACompiler(AfeToN2AOperationConverter):

    @staticmethod
    def gen_vertex(output_type: DataValue[TensorType],
                   attrs: Optional[attributes.AwesomeAttributes],
                   quant_attrs: Optional[attributes.AwesomeQuantAttrBase],
                   input_dict: Dict[InputName, Any],
                   data_layout: str,
                   name: str,
                   rounding_type: RoundType = RoundType.UPWARD) -> CompilerVertex:
        if quant_attrs is not None:
            return UdfN2ACompiler.gen_vertex(output_type, attrs, quant_attrs, input_dict,
                                             data_layout, name, rounding_type)
        else:
            assert attrs is not None
            data_vertex = input_dict['data']
            return n2a_compiler_utils.gen_compiler_erf_vertex(data_vertex, name)

class SliceN2ACompiler(AfeToN2AOperationConverter):

    @staticmethod
    def gen_vertex(output_type: DataValue[TensorType],
                   attrs: Optional[attributes.AwesomeAttributes],
                   quant_attrs: Optional[attributes.AwesomeQuantAttrBase],
                   input_dict: Dict[InputName, Any],
                   data_layout: str,
                   name: str,
                   rounding_type: RoundType = RoundType.UPWARD) -> CompilerVertex:
        data_vertex = input_dict['data']
        assert isinstance(attrs, attributes.StridedSliceAttrs)
        begin, end, strides = operations.expand_indices_to_shape_length(
            begin=attrs.begin, end=attrs.end, strides=attrs.strides, axes=attrs.axes,
            input_shape=list(attrs.input_shape))
        size = get_expected_tensor_value(output_type).shape
        if len(size) == 4:
            # NHWC -> 1HWC.
            size = (1, *size[1:])
            begin = (0, *begin[1:])
            strides = (1, *strides[1:])
        else:
            # NDHWC -> DHWC.
            size = size[1:]
            begin = begin[1:]
            strides = strides[1:]
        return n2a_compiler_utils.gen_compiler_slice_vertex(
            name, data_vertex, begin=begin, size=size, stride=strides
        )

class SoftmaxN2ACompiler(AfeToN2AOperationConverter):

    @staticmethod
    def gen_vertex(output_type: DataValue[TensorType],
                   attrs: Optional[attributes.AwesomeAttributes],
                   quant_attrs: Optional[attributes.AwesomeQuantAttrBase],
                   input_dict: Dict[InputName, Any],
                   data_layout: str,
                   name: str,
                   rounding_type: RoundType = RoundType.UPWARD) -> CompilerVertex:
        if quant_attrs is None:
            assert isinstance(attrs, attributes.SoftmaxAttrs)
            non_neg_axis = convert_axis_to_non_negative(attrs.axis,
                                                        len(attrs.input_shape), False)
            lut_exp_data = compute_bf16_exp_lut()
            lut_rec_data = compute_bf16_reciprocal_lut()
            exp_zp = rec_zp = lut_input_pre_shift = output_pre_shift = None
            requant_lut = requant_output = get_id_requantization(bfloat16)
        else:
            assert isinstance(quant_attrs, attributes.SoftmaxQuantAttrs)
            non_neg_axis = convert_axis_to_non_negative(quant_attrs.axis,
                                                        len(quant_attrs.input_shape),
                                                        False)
            lut_exp_data = quant_attrs.lookup_table_exp
            lut_rec_data = quant_attrs.lookup_table_rec
            exp_zp = quant_attrs.exp_zp
            rec_zp = quant_attrs.rec_zp
            requant_lut = quant_attrs.requant_lut
            requant_output = quant_attrs.requant_output
            lut_input_pre_shift = quant_attrs.lut_input_pre_shift
            output_pre_shift = quant_attrs.output_pre_shift

        assert non_neg_axis == data_layout.index("C"), \
            f"Operator's axis ({non_neg_axis}) != channel axis ({data_layout.index('C')})"
        axis = n2a_compiler_utils.axis_to_compiler_data_layout(non_neg_axis, data_layout)

        # Get input vertices
        data_vertex = input_dict["data"]
        lut_exp_vertex = n2a_compiler_utils.gen_compiler_constant_vertex(
            f"{name}_exp_lut", lut_exp_data
        )
        lut_rec_vertex = n2a_compiler_utils.gen_compiler_constant_vertex(
            f"{name}_rec_lut", lut_rec_data
        )
        return n2a_compiler_utils.gen_compiler_softmax_vertex(
            name, data_vertex=data_vertex, lut_exp_vertex=lut_exp_vertex,
            lut_rec_vertex=lut_rec_vertex, axis=axis, exp_zp=exp_zp, rec_zp=rec_zp,
            requant_lut=requant_lut, requant_output=requant_output,
            lut_input_pre_shift=lut_input_pre_shift, output_pre_shift=output_pre_shift
        )

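
# Conceptually, the vertex built above evaluates softmax along the channel axis,
#     softmax(x)_i = exp(x_i) / sum_j exp(x_j),
# with both nonlinearities realized as lookup tables: `lut_exp_data` for exp and
# `lut_rec_data` for the reciprocal. In bfloat16 mode the generic BF16 LUTs are used;
# in quantized mode calibrated tables and requantization parameters come from
# SoftmaxQuantAttrs.
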
class LayerNormN2ACompiler(AfeToN2AOperationConverter):

    @staticmethod
    def gen_vertex(output_type: DataValue[TensorType],
                   attrs: Optional[attributes.AwesomeAttributes],
                   quant_attrs: Optional[attributes.AwesomeQuantAttrBase],
                   input_dict: Dict[InputName, Any],
                   data_layout: str,
                   name: str,
                   rounding_type: RoundType = RoundType.UPWARD) -> CompilerVertex:
        if quant_attrs is None:
            assert isinstance(attrs, attributes.LayerNormAttrs)
            axis = attrs.axis
            epsilon = attrs.epsilon
            lut_vertex = n2a_compiler_utils.gen_compiler_constant_vertex(
                f"{name}_rsqrt_lut", compute_bf16_rsqrt_lut()
            )
            input_shape = attrs.input_shape
            zp_rsqrt = None
            requant_mean = None
            requant_lut_input = None
            requant_output = None
        else:
            assert isinstance(quant_attrs, attributes.LayerNormQuantAttrs)
            axis = quant_attrs.axis
            epsilon = None
            input_shape = quant_attrs.input_shape
            lut_vertex = n2a_compiler_utils.gen_compiler_constant_vertex(
                name + '_rsqrt_lut', quant_attrs.lookup_table_rsqrt.reshape(16, 16))
            zp_rsqrt = quant_attrs.zp_rsqrt
            requant_mean = quant_attrs.requant_mean
            requant_lut_input = quant_attrs.requant_lut_input
            requant_output = quant_attrs.requant_output

        non_neg_axis = convert_axis_to_non_negative(axis, len(input_shape), False)
        assert non_neg_axis == data_layout.index("C"), \
            f"Operator's axis ({non_neg_axis}) != channel axis ({data_layout.index('C')})"
        axis_compiler_data_layout = n2a_compiler_utils.axis_to_compiler_data_layout(
            non_neg_axis, data_layout)

        # Get input vertices.
        data_vertex = input_dict["data"]
        return n2a_compiler_utils.gen_compiler_layer_norm_vertex(
            name, data_vertex=data_vertex, lut_vertex=lut_vertex,
            axis=axis_compiler_data_layout, epsilon=epsilon, rsqrt_zp=zp_rsqrt,
            requant_mean=requant_mean, requant_lut_input=requant_lut_input,
            requant_output=requant_output)

class RMSNormN2ACompiler(AfeToN2AOperationConverter):

    @staticmethod
    def gen_vertex(output_type: DataValue[TensorType],
                   attrs: Optional[attributes.AwesomeAttributes],
                   quant_attrs: Optional[attributes.AwesomeQuantAttrBase],
                   input_dict: Dict[InputName, Any],
                   data_layout: str,
                   name: str,
                   rounding_type: RoundType = RoundType.UPWARD) -> CompilerVertex:
        # Get input vertices.
        data_vertex = input_dict["data"]
        axis = n2a_compiler_utils.axis_to_compiler_data_layout(
            data_layout.index("C"), data_layout)
        if quant_attrs is None:
            assert isinstance(attrs, attributes.RMSNormAttrs)
            lut_data = compute_bf16_rsqrt_lut()
            lut_vertex = n2a_compiler_utils.gen_compiler_constant_vertex(
                f"{name}_rsqrt_lut", lut_data
            )
            return gen_compiler_rms_norm_vertex(
                name, data_vertex, lut_vertex, axis, epsilon=attrs.epsilon
            )
        else:
            assert isinstance(quant_attrs, attributes.RMSNormQuantAttrs)
            lut_vertex = n2a_compiler_utils.gen_compiler_constant_vertex(
                f"{name}_rsqrt_lut", quant_attrs.lookup_table_rsqrt
            )
            return gen_compiler_rms_norm_vertex(
                name, data_vertex, lut_vertex, axis,
                zp_ifm=quant_attrs.zp_ifm, zp_rsqrt=quant_attrs.zp_rsqrt,
                requant_lut=quant_attrs.requant_lut_input,
                requant_output=quant_attrs.requant_output,
                lut_input_pre_shift=quant_attrs.lut_input_pre_shift,
                output_pre_shift=quant_attrs.output_pre_shift,
                epsilon=None
            )


class InstanceNormN2ACompiler(AfeToN2AOperationConverter):
    @staticmethod
    def gen_vertex(output_type: DataValue[TensorType],
                   attrs: Optional[attributes.AwesomeAttributes],
                   quant_attrs: Optional[attributes.AwesomeQuantAttrBase],
                   input_dict: Dict[InputName, Any],
                   data_layout: str,
                   name: str,
                   rounding_type: RoundType = RoundType.UPWARD) -> CompilerVertex:
        # Get input vertices.
        data_vertex = input_dict[InputName("data")]
        mean_vertex = input_dict[InputName("mean")]
        variance_vertex = input_dict[InputName("variance")]

        is_quantized = quant_attrs is not None
        if is_quantized:
            assert isinstance(quant_attrs, attributes.InstanceNormQuantAttrs)
            lut_vertex = n2a_compiler_utils.gen_compiler_constant_vertex(
                f"{name}_rsqrt_lut", quant_attrs.lut_rsqrt
            )
            requant = quant_attrs.requant_out
            zp_rsqrt = quant_attrs.zp_rsqrt
            epsilon = None
        else:
            assert isinstance(attrs, attributes.InstanceNormAttrs)
            lut_vertex = n2a_compiler_utils.gen_compiler_constant_vertex(
                f"{name}_rsqrt_lut", compute_bf16_rsqrt_lut()
            )
            requant = None
            zp_rsqrt = None
            epsilon = attrs.epsilon

        return n2a_compiler_utils.gen_compiler_instance_norm_vertex(
            name, data_vertex, mean_vertex, variance_vertex, lut_vertex,
            zp_rsqrt, requant, epsilon
        )

class GridSampleN2ACompiler(AfeToN2AOperationConverter):
    @staticmethod
    def gen_vertex(output_type: DataValue[TensorType],
                   attrs: Optional[attributes.AwesomeAttributes],
                   quant_attrs: Optional[attributes.AwesomeQuantAttrBase],
                   input_dict: Dict[InputName, Any],
                   data_layout: str,
                   name: str,
                   rounding_type: RoundType = RoundType.UPWARD) -> CompilerVertex:
        data_vertex = input_dict["data"]
        grid_vertex = input_dict["grid"]
        assert quant_attrs is None and isinstance(attrs, attributes.GridSampleAttrs), \
            "Only bfloat16 is supported for GridSample"
        padding_mode = attrs.padding_mode
        align_corners = attrs.align_corners
        return n2a_compiler_utils.gen_compiler_grid_sample_vertex(
            name, data_vertex, grid_vertex, padding_mode, align_corners
        )

class BatchMatmulN2ACompiler(AfeToN2AOperationConverter):
    @staticmethod
    def gen_vertex(output_type: DataValue[TensorType],
                   attrs: Optional[attributes.AwesomeAttributes],
                   quant_attrs: Optional[attributes.AwesomeQuantAttrBase],
                   input_dict: Dict[InputName, Any],
                   data_layout: str,
                   name: str,
                   rounding_type: RoundType = RoundType.UPWARD) -> CompilerVertex:
        assert isinstance(quant_attrs, (attributes.BatchMatmulQuantAttrs, NoneType))

        # Get input vertices
        lhs_vertex = input_dict["lhs"]
        rhs_vertex = input_dict["rhs"]

        is_quantized = quant_attrs is not None
        if is_quantized:
            transpose_b = quant_attrs.attrs.transpose_b
            input_zps = [quant_attrs.lhs_zp, quant_attrs.rhs_zp]
            requant = quant_attrs.requant
            intrinsic_shift = quant_attrs.intrinsic_shift
        else:
            assert isinstance(attrs, attributes.BatchMatmulAttrs)
            transpose_b = attrs.transpose_b
            input_zps = [0, 0]
            # Use the identity requantization for the bfloat16 case, consistent
            # with the other converters in this module.
            requant = get_id_requantization(bfloat16)
            intrinsic_shift = 0

        return n2a_compiler_utils.gen_compiler_batch_matmul_vertex(
            name, lhs=lhs_vertex, rhs=rhs_vertex, transpose_b=transpose_b,
            input_zps=input_zps, requant=requant, intrinsic_shift=intrinsic_shift
        )
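
# Illustrative sketch (hypothetical helper): assuming transpose_b follows the
# usual batched-matmul convention of transposing the last two axes of rhs, the
# vertex computes:
def _batch_matmul_reference_sketch(lhs: np.ndarray, rhs: np.ndarray,
                                   transpose_b: bool) -> np.ndarray:
    if transpose_b:
        rhs = np.swapaxes(rhs, -1, -2)
    return np.matmul(lhs, rhs)
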

class UnaryBatchMatmulN2ACompiler(AfeToN2AOperationConverter):
    @staticmethod
    def gen_vertex(output_type: DataValue[TensorType],
                   attrs: Optional[attributes.AwesomeAttributes],
                   quant_attrs: Optional[attributes.AwesomeQuantAttrBase],
                   input_dict: Dict[InputName, Any],
                   data_layout: str,
                   name: str,
                   rounding_type: RoundType = RoundType.UPWARD) -> CompilerVertex:
        assert isinstance(quant_attrs, attributes.BatchMatmulQuantAttrs)

        # Get input vertices. Both matmul operands are the same vertex.
        data_vertex = input_dict["data"]

        return n2a_compiler_utils.gen_compiler_batch_matmul_vertex(
            name, lhs=data_vertex, rhs=data_vertex,
            transpose_b=quant_attrs.attrs.transpose_b,
            input_zps=[quant_attrs.lhs_zp, quant_attrs.rhs_zp],
            requant=quant_attrs.requant,
            intrinsic_shift=quant_attrs.intrinsic_shift
        )

class SliceConcatN2ACompiler(AfeToN2AOperationConverter):
    @staticmethod
    def gen_vertex(output_type: DataValue[TensorType],
                   attrs: attributes.AwesomeAttributes,
                   quant_attrs: attributes.AwesomeQuantAttrBase,
                   input_dict: Dict[InputName, Any],
                   data_layout: str,
                   name: str,
                   rounding_type: RoundType = RoundType.UPWARD) -> CompilerVertex:
        if quant_attrs is None:
            assert isinstance(attrs, attributes.SliceConcatAttrs)
            slice_attrs = attrs.slice_attrs
            concat_attrs = attrs.tuple_concat_attrs.concat_attrs
            requants = [get_id_requantization(bfloat16)]
        else:
            assert isinstance(quant_attrs, attributes.SliceConcatQuantAttrs)
            slice_attrs = quant_attrs.slice_attrs
            tuple_concat_attrs = quant_attrs.tuple_concat_attrs
            concat_attrs = tuple_concat_attrs.attrs
            requants = tuple_concat_attrs.requants

        split_axis = n2a_compiler_utils.axis_to_compiler_data_layout(
            slice_attrs[0].axes[0], AwesomeDataLayout
        )
        split_block = len(slice_attrs)
        concat_axis = n2a_compiler_utils.axis_to_compiler_data_layout(
            concat_attrs.axis, AwesomeDataLayout
        )
        input_vertex = input_dict["data"]
        return n2a_compiler_utils.gen_compiler_concat_vertex(
            name, [input_vertex], requants=requants, axis=concat_axis,
            split_axis=split_axis, split_block=split_block
        )

class BroadcastToN2ACompiler(AfeToN2AOperationConverter):
    @staticmethod
    def gen_vertex(output_type: DataValue[TensorType],
                   attrs: attributes.AwesomeAttributes,
                   quant_attrs: attributes.AwesomeQuantAttrBase,
                   input_dict: Dict[InputName, Any],
                   data_layout: str,
                   name: str,
                   rounding_type: RoundType = RoundType.UPWARD) -> CompilerVertex:
        if quant_attrs is None:
            output_shape = attrs.output_shape
        else:
            assert isinstance(quant_attrs, attributes.BroadcastToQuantAttrs)
            output_shape = quant_attrs.output_shape

        # Drop the leading batch dimension from 5D shapes; 4D shapes must have
        # a unit batch.
        if len(output_shape) == 5:
            output_shape = output_shape[1:]
        else:
            assert len(output_shape) == 4 and output_shape[0] == 1

        input_vertex = input_dict["data"]
        return n2a_compiler_utils.gen_compiler_broadcast_to_vertex(name, input_vertex, output_shape)

class RequantizeN2ACompiler(AfeToN2AOperationConverter):
    @staticmethod
    def gen_vertex(output_type: DataValue[TensorType],
                   attrs: attributes.AwesomeAttributes,
                   quant_attrs: attributes.AwesomeQuantAttrBase,
                   input_dict: Dict[InputName, Any],
                   data_layout: str,
                   name: str,
                   rounding_type: RoundType = RoundType.UPWARD) -> CompilerVertex:
        assert isinstance(quant_attrs, attributes.RequantizeQuantAttrs)
        requant = quant_attrs.requant
        input_vertex = input_dict["data"]
        return n2a_compiler_utils.gen_compiler_requantization_vertex(name, input_vertex, requant)

class DequantizationN2ACompiler(AfeToN2AOperationConverter):
    @staticmethod
    def gen_vertex(output_type: DataValue[TensorType],
                   attrs: attributes.AwesomeAttributes,
                   quant_attrs: attributes.AwesomeQuantAttrBase,
                   input_dict: Dict[InputName, Any],
                   data_layout: str,
                   name: str,
                   rounding_type: RoundType = RoundType.UPWARD) -> CompilerVertex:
        assert isinstance(attrs, attributes.DequantizationTransformAttrs)
        assert len(attrs.channel_params) == 1
        scale, zp = attrs.channel_params[0]
        requant = FloatRequantization(
            sc_correction=np.float32(scale), zp_correction=zp,
            out_dtype=attrs.output_type.numpy_type()
        )
        input_vertex = input_dict["data"]
        return n2a_compiler_utils.gen_compiler_requantization_vertex(name, input_vertex, requant)

class QuantizationN2ACompiler(AfeToN2AOperationConverter):
    @staticmethod
    def gen_vertex(output_type: DataValue[TensorType],
                   attrs: attributes.AwesomeAttributes,
                   quant_attrs: attributes.AwesomeQuantAttrBase,
                   input_dict: Dict[InputName, Any],
                   data_layout: str,
                   name: str,
                   rounding_type: RoundType = RoundType.UPWARD) -> CompilerVertex:
        assert len(attrs.channel_params) == 1
        scale, zp = attrs.channel_params[0]
        requant = FloatRequantization(
            sc_correction=np.float32(scale), zp_correction=zp,
            out_dtype=attrs.output_data_type.numpy_type()
        )
        input_vertex = input_dict["data"]
        return n2a_compiler_utils.gen_compiler_requantization_vertex(name, input_vertex, requant)
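
# Illustrative sketch (hypothetical helpers): both transform converters above
# reduce to a single-channel affine mapping between float and integer values.
# In the standard affine scheme this looks like the following; the exact
# correction terms applied by FloatRequantization may differ:
def _affine_quantize_sketch(x: np.ndarray, scale: float, zp: int) -> np.ndarray:
    # float -> integer: divide by scale, round, shift by the zero point
    return np.round(x / scale).astype(np.int64) + zp

def _affine_dequantize_sketch(q: np.ndarray, scale: float, zp: int) -> np.ndarray:
    # integer -> float: remove the zero point, multiply by the scale
    return (q.astype(np.float32) - zp) * scale
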

class TransposeN2ACompiler(AfeToN2AOperationConverter):
    @staticmethod
    def gen_vertex(output_type: DataValue[TensorType],
                   attrs: attributes.AwesomeAttributes,
                   quant_attrs: attributes.AwesomeQuantAttrBase,
                   input_dict: Dict[InputName, Any],
                   data_layout: str,
                   name: str,
                   rounding_type: RoundType = RoundType.UPWARD) -> CompilerVertex:
        assert isinstance(attrs, attributes.TransposeAttrs)

        # Convert the permutation axes to the compiler's HWC layout: 4D axes
        # keep the batch axis in place, 5D axes drop the leading axis.
        if len(attrs.axes) == 4:
            new_axes = (0, *attrs.axes[1:])
        else:
            assert len(attrs.axes) == 5
            new_axes = tuple(axis - 1 for axis in attrs.axes[1:])

        input_vertex = input_dict["data"]
        return n2a_compiler_utils.gen_compiler_transpose_vertex(name, input_vertex, perm=new_axes)

class DepthToSpaceN2ACompiler(AfeToN2AOperationConverter):
    @staticmethod
    def gen_vertex(output_type: DataValue[TensorType],
                   attrs: attributes.AwesomeAttributes,
                   quant_attrs: attributes.AwesomeQuantAttrBase,
                   input_dict: Dict[InputName, Any],
                   data_layout: str,
                   name: str,
                   rounding_type: RoundType = RoundType.UPWARD) -> CompilerVertex:
        assert isinstance(attrs, attributes.DepthToSpaceAttrs)
        input_vertex = input_dict["data"]
        return n2a_compiler_utils.gen_compiler_depth_to_space_vertex(
            name, input_vertex, block_size=attrs.block_size, mode=attrs.mode
        )

class DivideN2ACompiler(AfeToN2AOperationConverter):
    @staticmethod
    def gen_vertex(output_type: DataValue[TensorType],
                   attrs: attributes.AwesomeAttributes,
                   quant_attrs: attributes.AwesomeQuantAttrBase,
                   input_dict: Dict[InputName, Any],
                   data_layout: str,
                   name: str,
                   rounding_type: RoundType = RoundType.UPWARD) -> CompilerVertex:
        assert isinstance(quant_attrs, attributes.DivideQuantAttrs)
        lhs_vertex = input_dict["lhs"]
        rhs_vertex = input_dict["rhs"]

        # Divide is lowered as lhs * reciprocal(rhs): a UDF vertex computes the
        # reciprocal of rhs, then a multiply vertex combines it with lhs.
        reciprocal_vertex = UdfN2ACompiler.gen_vertex(
            output_type, None, quant_attrs.udf_attrs, {'data': rhs_vertex},
            data_layout, f'{name}_rhs_reciprocal', rounding_type
        )
        return MultiplyN2ACompiler.gen_vertex(
            output_type, None, quant_attrs.multiply_attrs,
            {'lhs': lhs_vertex, 'rhs': reciprocal_vertex},
            data_layout, name, rounding_type
        )
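
# Illustrative sketch (hypothetical helper): in floating point, the
# decomposition above is equivalent to:
def _divide_reference_sketch(lhs: np.ndarray, rhs: np.ndarray) -> np.ndarray:
    # reciprocal UDF followed by an elementwise multiply
    return lhs * (1.0 / rhs)
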

class SigmoidN2ACompiler(AfeToN2AOperationConverter):
    @staticmethod
    def gen_vertex(output_type: DataValue[TensorType],
                   attrs: attributes.AwesomeAttributes,
                   quant_attrs: attributes.AwesomeQuantAttrBase | None,
                   input_dict: Dict[InputName, Any],
                   data_layout: str,
                   name: str,
                   rounding_type: RoundType = RoundType.UPWARD) -> CompilerVertex:
        if isinstance(quant_attrs, attributes.UDFQuantAttrs):
            return UdfN2ACompiler.gen_vertex(
                output_type, attrs, quant_attrs, input_dict, data_layout, name, rounding_type
            )
        else:
            # For bfloat16, build sigmoid from an exp LUT and a reciprocal LUT.
            assert quant_attrs is None and isinstance(attrs, attributes.UDFAttrs)
            lut_exp = n2a_compiler_utils.gen_compiler_constant_vertex(
                f"{name}_lut_exp", compute_bf16_exp_lut()
            )
            lut_rec = n2a_compiler_utils.gen_compiler_constant_vertex(
                f"{name}_lut_rec", compute_bf16_reciprocal_lut()
            )
            return n2a_compiler_utils.gen_compiler_sigmoid_vertex(
                name, input_dict["data"], lut_exp, lut_rec
            )

class SwishN2ACompiler(AfeToN2AOperationConverter):
    @staticmethod
    def gen_vertex(output_type: DataValue[TensorType],
                   attrs: attributes.AwesomeAttributes,
                   quant_attrs: attributes.AwesomeQuantAttrBase | None,
                   input_dict: Dict[InputName, Any],
                   data_layout: str,
                   name: str,
                   rounding_type: RoundType = RoundType.UPWARD) -> CompilerVertex:
        if isinstance(quant_attrs, attributes.UDFQuantAttrs):
            return UdfN2ACompiler.gen_vertex(
                output_type, attrs, quant_attrs, input_dict, data_layout, name, rounding_type
            )
        else:
            # For bfloat16, rewrite swish(x) as x * sigmoid(x).
            assert quant_attrs is None and isinstance(attrs, attributes.UDFAttrs)
            sigmoid_vertex = SigmoidN2ACompiler.gen_vertex(
                output_type, attrs, None, input_dict, data_layout,
                f"{name}/sigmoid", rounding_type
            )
            multiply_attrs = attributes.MultiplyAttrs(
                attrs.scalar_type, attrs.input_shape, attrs.input_shape
            )
            multiply_input_dict = {"lhs": input_dict["data"], "rhs": sigmoid_vertex}
            return MultiplyN2ACompiler.gen_vertex(
                output_type, multiply_attrs, None, multiply_input_dict,
                data_layout, f"{name}/multiply", rounding_type
            )
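
# Illustrative sketch (hypothetical helper): the bfloat16 swish rewrite above
# corresponds to the identity swish(x) = x * sigmoid(x):
def _swish_reference_sketch(x: np.ndarray) -> np.ndarray:
    return x * (1.0 / (1.0 + np.exp(-x)))
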

class TanhN2ACompiler(AfeToN2AOperationConverter):
    @staticmethod
    def gen_vertex(output_type: DataValue[TensorType],
                   attrs: attributes.AwesomeAttributes,
                   quant_attrs: attributes.AwesomeQuantAttrBase | None,
                   input_dict: Dict[InputName, Any],
                   data_layout: str,
                   name: str,
                   rounding_type: RoundType = RoundType.UPWARD) -> CompilerVertex:
        if isinstance(quant_attrs, attributes.UDFQuantAttrs):
            return UdfN2ACompiler.gen_vertex(
                output_type, attrs, quant_attrs, input_dict, data_layout, name, rounding_type
            )
        else:
            # For bfloat16, rewrite tanh(x) as 2 * sigmoid(2x) - 1.
            assert quant_attrs is None and isinstance(attrs, attributes.UDFAttrs)

            # 2x
            constant_vertex_1 = n2a_compiler_utils.gen_compiler_constant_vertex(
                f"{name}/const_1", np.array([2]).astype(bfloat16)
            )
            multiply_attrs_1 = attributes.MultiplyAttrs(
                attrs.scalar_type, attrs.input_shape, (1,)
            )
            multiply_input_dict_1 = {"lhs": input_dict["data"], "rhs": constant_vertex_1}
            multiply_vertex_1 = MultiplyN2ACompiler.gen_vertex(
                output_type, multiply_attrs_1, None, multiply_input_dict_1,
                data_layout, f"{name}/multiply_1", rounding_type
            )

            # sigmoid(2x)
            sigmoid_vertex = SigmoidN2ACompiler.gen_vertex(
                output_type, attrs, None, {"data": multiply_vertex_1},
                data_layout, f"{name}/sigmoid", rounding_type
            )

            # 2 * sigmoid(2x)
            constant_vertex_2 = n2a_compiler_utils.gen_compiler_constant_vertex(
                f"{name}/const_2", np.array([2]).astype(bfloat16)
            )
            multiply_attrs_2 = attributes.MultiplyAttrs(
                attrs.scalar_type, attrs.input_shape, (1,)
            )
            multiply_input_dict_2 = {"lhs": sigmoid_vertex, "rhs": constant_vertex_2}
            multiply_vertex_2 = MultiplyN2ACompiler.gen_vertex(
                output_type, multiply_attrs_2, None, multiply_input_dict_2,
                data_layout, f"{name}/multiply_2", rounding_type
            )

            # 2 * sigmoid(2x) - 1
            constant_vertex_3 = n2a_compiler_utils.gen_compiler_constant_vertex(
                f"{name}/const_3", np.array([1]).astype(bfloat16)
            )
            subtract_attrs = attributes.SubtractAttrs(
                attrs.scalar_type, attrs.input_shape, (1,)
            )
            subtract_input_dict = {"lhs": multiply_vertex_2, "rhs": constant_vertex_3}
            return SubtractN2ACompiler.gen_vertex(
                output_type, subtract_attrs, None, subtract_input_dict,
                data_layout, f"{name}/subtract", rounding_type
            )
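
# Illustrative sketch (hypothetical helper): the rewrite above relies on the
# identity tanh(x) = 2 * sigmoid(2x) - 1, which can be checked numerically:
def _tanh_identity_check_sketch() -> None:
    x = np.linspace(-4.0, 4.0, 17)
    sigmoid_2x = 1.0 / (1.0 + np.exp(-2.0 * x))
    assert np.allclose(2.0 * sigmoid_2x - 1.0, np.tanh(x))
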
# Dictionary used in converting Awesome Operations to Compiler Vertices
AFE_TO_N2ACOMPILER_OPERATOR_DICT: Dict[Type[operations.AwesomeOperation], Type[AfeToN2AOperationConverter]] = {
    operations.PlaceholderOp: PlaceholderN2ACompiler,
    operations.ConstantOp: ConstantN2ACompiler,
    operations.TupleOp: TupleN2A,
    operations.ConcatenateOp: ConcatenateN2A,
    operations.MaxPool2DOp: MaxPoolN2ACompiler,
    operations.AvgPool2DOp: AvgPoolN2ACompiler,
    operations.AvgPool3DOp: AvgPoolN2ACompiler,
    operations.MeanOp: MeanN2ACompiler,
    operations.LRNOp: LRNN2ACompiler,
    operations.LeakyReluCompositeOp: LeakyReluN2ACompiler,
    operations.UpsamplingOp: UpsamplingN2ACompiler,
    operations.ImageResize2DOp: ImageResize2DN2ACompiler,
    operations.MultiplyOp: MultiplyN2ACompiler,
    operations.SqrtOp: UdfN2ACompiler,
    operations.RsqrtOp: UdfN2ACompiler,
    operations.TanhOp: TanhN2ACompiler,
    operations.SigmoidOp: SigmoidN2ACompiler,
    operations.LogOp: UdfN2ACompiler,
    operations.Log2Op: UdfN2ACompiler,
    operations.Log10Op: UdfN2ACompiler,
    operations.ExpOp: UdfN2ACompiler,
    operations.SwishOp: SwishN2ACompiler,
    operations.EluOp: UdfN2ACompiler,
    operations.PReluOp: PreluN2A,
    operations.ReluOp: ReluN2ACompiler,
    operations.SoftplusOp: UdfN2ACompiler,
    operations.ErfOp: ErfN2ACompiler,
    operations.HardSwishOp: UdfN2ACompiler,
    operations.HardSigmoidOp: UdfN2ACompiler,
    operations.SubtractOp: SubtractN2ACompiler,
    operations.ArgMaxOp: ArgMaxN2ACompiler,
    operations.SoftmaxOp: SoftmaxN2ACompiler,
    operations.RequantizeOp: RequantizeN2ACompiler,
    operations.TransposeOp: TransposeN2ACompiler,
    operations.DepthToSpaceOp: DepthToSpaceN2ACompiler,
    operations.DequantizationTransformOp: DequantizationN2ACompiler,
    operations.QuantizationTransformOp: QuantizationN2ACompiler,
    operations.VarianceOp: VarianceN2ACompiler,
    operations.ClipOp: ClipN2ACompiler,
    # COMPOSITE OPERATIONS
    operations.ConvAddActivationOp: ConvN2ACompiler,
    operations.AddActivationOp: AddActivationN2ACompiler,
    operations.TupleConcatenateOp: TupleConcatenateN2A,
    operations.ConstantMultiplyAddOp: ConstantMultiplyAddN2ACompiler,
    operations.StridedSliceOp: SliceN2ACompiler,
    operations.LayerNormOp: LayerNormN2ACompiler,
    operations.BatchMatmulOp: BatchMatmulN2ACompiler,
    operations.UnaryBatchMatmulOp: UnaryBatchMatmulN2ACompiler,
    operations.SliceConcatOp: SliceConcatN2ACompiler,
    operations.BroadcastToOp: BroadcastToN2ACompiler,
    operations.DivideOp: DivideN2ACompiler,
    operations.RMSNormOp: RMSNormN2ACompiler,
    operations.GeluOp: UdfN2ACompiler,
    operations.GridSampleOp: GridSampleN2ACompiler,
    operations.InstanceNormOp: InstanceNormN2ACompiler
}

def get_afe_to_n2a_converter(node: AwesomeNode) -> Type[AfeToN2AOperationConverter]:
    """
    Utility function for getting the AFE node to CompilerVertex converter from
    the converters' dictionary.

    :param node: AwesomeNode that is to be converted to CompilerVertex.
    :return: The converter type from the AFE_TO_N2ACOMPILER_OPERATOR_DICT.
    """
    _converters: Dict[Type[operations.AwesomeOperation], Type[AfeToN2AOperationConverter]] = \
        AFE_TO_N2ACOMPILER_OPERATOR_DICT

    # Get converter for the SimaIR operator
    if type(node.ir.operation) not in _converters:
        raise TypeError(f"n2a compiler does not support operation "
                        f"{node.ir.operation.__class__.__name__} in {node.name} node!")
    return _converters[type(node.ir.operation)]
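
# Illustrative usage sketch (hypothetical, not part of the compiler path): a
# direct dictionary lookup mirrors what get_afe_to_n2a_converter does once the
# node's operation type is known:
def _lookup_converter_sketch() -> None:
    assert AFE_TO_N2ACOMPILER_OPERATOR_DICT[operations.SoftmaxOp] is SoftmaxN2ACompiler
    assert AFE_TO_N2ACOMPILER_OPERATOR_DICT[operations.RMSNormOp] is RMSNormN2ACompiler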