Source code for afe.tvm_converter.custom_convert_maps

#########################################################
# Copyright (C) 2021 SiMa Technologies, Inc.
#
# This material is SiMa proprietary and confidential.
#
# This material may not be copied or distributed without
# the express prior written permission of SiMa.
#
# All rights reserved.
#########################################################
# Code owner: Joey Chou
#########################################################
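"""
Custom TVM frontend converters and composite-function attribute translators.

TODO(Joey.Chou):
    Replace the hard-coded dictionaries with a registration mechanism for
        * custom_convert_map
        * the MergeComposite pattern table (_PATTERN_TABLE)
        * tvm_func_to_awesome_attributes
"""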
import copy

import numpy as np
import tvm
from tvm import relay
from tvm.relay.frontend.common import infer_type
from afe.backends import Backend
from afe.ir.attributes import AwesomeAttributes
from afe.ir.tensor_type import ScalarType
import afe.ir.utils as utils
import afe._tvm._defines as tvm_def
from afe._tvm._utils import get_expr_input_shapes, get_expr_output_shapes
from afe.tvm_converter.parameters import TVMConverterParams


class GroupConv2DTranspose():
    """
    Models: CenterNet
    Author: Alicja Kwasniewska, Joey Chou

    Decompose an nn.conv2d_transpose with groups > 1 into a split, one
    nn.conv2d_transpose per group, and a concatenate, as below:

                              |
                            split
              --------------------------------------- ....
              |                  |                  |
        tuple_get_item     tuple_get_item     tuple_get_item    ....
              |                  |                  |
      nn.conv2d_transpose  nn.conv2d_transpose  nn.conv2d_transpose ...
              |                  |                  |
              --------------------------------------- ....
                              |
                            tuple
                              |
                         concatenate
                              |
    """

    # Name of the ONNX operator whose conversion this class overrides.
    ONNX_op_name = "ConvTranspose"

    @staticmethod
    def map_impl():
        """
        Build the ONNX ``ConvTranspose`` converter used in the custom convert map.

        Single-group (and depthwise, as decided by ``utils.is_depthwise_conv``)
        transposed convolutions are delegated to TVM's own ONNX ``ConvTranspose``
        converter. Grouped ones are decomposed into per-group
        ``nn.conv2d_transpose`` calls joined by a ``concatenate`` on axis 1.

        Returns
        -------
        Callable
            ``_impl(inputs, attr, params)`` returning the converted Relay expression.
        """
        def _conv_per_group(data_splitted, weight_splitted, groups, strides, pads,
                            dilations, kernel_shape, output_padding):
            """Emit one nn.conv2d_transpose per group and concatenate them on axis 1."""
            convs = [
                relay.nn.conv2d_transpose(
                    relay.TupleGetItem(data_splitted, group_index),
                    relay.TupleGetItem(weight_splitted, group_index),
                    strides=strides,
                    padding=pads,
                    dilation=dilations,
                    kernel_size=kernel_shape,
                    output_padding=output_padding,
                )
                for group_index in range(groups)
            ]
            return relay.concatenate(convs, axis=1)

        def _impl(inputs, attr, params):
            assert 2 <= len(inputs) <= 3, (f"Inputs of ONNX's {GroupConv2DTranspose.ONNX_op_name} operators "
                                           f"must have either 2 inputs (data, weights) or 3 inputs (data, weights, "
                                           f"bias). Got {len(inputs)} inputs")
            data_shape = infer_type(inputs[0]).checked_type.shape
            weight_shape = infer_type(inputs[1]).checked_type.shape
            num_spatial_dims = len(data_shape) - 2

            # Default strides if they're not specified. This is written back into
            # ``attr`` so that TVM's converter (below) sees it as well.
            attr['strides'] = attr.get('strides', (1,) * num_spatial_dims)

            # In ONNX, the default data dimensions is NCHW.
            data_channel_axis = 1
            # ONNX ConvTranspose weights are (C_in, C_out / group, kH, kW), so axis 0
            # is the input-channel axis that is split across groups.
            weight_split_axis = 0
            groups = attr.get('group', 1)
            # NOTE(review): weight_shape[0] is the input-channel count for ConvTranspose;
            # confirm this is what utils.is_depthwise_conv expects as its second argument.
            if groups == 1 or utils.is_depthwise_conv(data_shape[data_channel_axis],
                                                      weight_shape[weight_split_axis], groups):
                # Use TVM's translation for transposed convolution
                from tvm.relay.frontend.onnx import ConvTranspose
                return ConvTranspose.get_converter(opset=17)(inputs, attr, params)

            # Else, it is transposed convolution with multiple groups.
            # There is no SiMa IR support for grouped transposed convolution,
            # so override TVM's frontend to decompose it into multiple transposed convolutions.
            assert groups > 1

            # ONNX makes these attributes optional; fill in the spec defaults instead of
            # failing with a KeyError.
            dilations = attr.get('dilations', (1,) * num_spatial_dims)
            output_padding = attr.get('output_padding', (0,) * num_spatial_dims)
            if 'kernel_shape' in attr:
                kernel_shape = attr['kernel_shape']
            else:
                kernel_shape = tuple(int(dim) for dim in weight_shape[2:])
            if 'pads' in attr:
                pads = attr['pads']
            else:
                auto_pad = attr.get('auto_pad', 'NOTSET')
                if isinstance(auto_pad, bytes):
                    auto_pad = auto_pad.decode('utf-8')
                if auto_pad not in ('NOTSET', 'VALID'):
                    # Zero padding would silently be wrong for SAME_UPPER / SAME_LOWER.
                    raise NotImplementedError(
                        f"auto_pad={auto_pad} is not supported for grouped "
                        f"{GroupConv2DTranspose.ONNX_op_name} without explicit pads")
                pads = (0,) * (2 * num_spatial_dims)

            data_splitted = relay.split(inputs[0], groups, axis=data_channel_axis).astuple()
            weight_splitted = relay.split(inputs[1], groups, axis=weight_split_axis).astuple()
            out = _conv_per_group(data_splitted, weight_splitted, groups, attr['strides'],
                                  pads, dilations, kernel_shape, output_padding)

            # If there are 3 inputs, the last input is the bias tensor (it may be a None
            # placeholder when the optional input is absent).
            if len(inputs) == 3 and inputs[2] is not None:
                out = relay.nn.bias_add(out, inputs[2])
            return out

        return _impl

    @staticmethod
    def tvm_func_to_awesome_attributes(params: TVMConverterParams,
                                       func: tvm_def.TVMFunction) -> AwesomeAttributes:
        """
        Extract the attributes of a decomposed grouped conv2d_transpose composite
        function and translate them into SiMa IR ConvAddActivationAttrs.

        :param params: Converter parameters (not read by this function).
        :param func: Composite Relay function produced for "group_conv2d_transpose".
        :return: ConvAddActivationAttrs holding the concatenated weights and the
            grouped convolution attributes.
        """
        from afe.ir.attributes import ConstantAttrs, ConvAddActivationAttrs
        from afe.tvm_converter._attributes import (
            get_tvm_call_op_attr_in_dict, get_tvm_constant_attr_in_dict,
            extract_expressions_from_function_used_in_composite_function_attributes
        )
        from afe.tvm_converter._operators import _make_conv_attributes

        # Extract all Relay expressions in the composite function
        composite_fn_expressions = extract_expressions_from_function_used_in_composite_function_attributes(func)

        conv2d_transpose_expressions = [expr for expr in composite_fn_expressions
                                        if isinstance(expr, tvm_def.TVMCall)
                                        and expr.op.name == "nn.conv2d_transpose"]
        layout_transform_expressions = [expr for expr in composite_fn_expressions
                                        if isinstance(expr, tvm_def.TVMCall)
                                        and expr.op.name == "layout_transform"]
        weight_data = [get_tvm_constant_attr_in_dict(expr)['data'] for expr in composite_fn_expressions
                       if isinstance(expr, tvm_def.TVMConstant)]
        # Every per-group conv shares the same attributes; only the first is needed.
        conv2d_transpose_attrs = get_tvm_call_op_attr_in_dict(conv2d_transpose_expressions[0])

        # Get input/output shapes by transposing to the layout in use and then
        # multiplying the channels by the number of groups.
        if len(layout_transform_expressions) > 0:
            current_layout = layout_transform_expressions[0].attrs.src_layout
        else:
            current_layout = "NHWC"
        dst_layout = "NHWC"
        input_shape = get_expr_input_shapes(conv2d_transpose_expressions[0])[0]
        input_shape = utils.transpose_attr_according_to_layout_strings(input_shape, dst_layout, current_layout)
        output_shape = get_expr_output_shapes(conv2d_transpose_expressions[0])[0]
        output_shape = utils.transpose_attr_according_to_layout_strings(output_shape, dst_layout, current_layout)

        # Create attributes for the grouped convolution corresponding to this composite operator.
        # NOTE(review): this counts every constant in the composite as one group's weight;
        # confirm no other constant (e.g. a bias) can be captured by the pattern.
        groups = len(weight_data)
        grouped_attrs = copy.deepcopy(conv2d_transpose_expressions[0].attrs)
        grouped_attrs.groups = groups
        grouped_attrs.channels *= groups
        grouped_attrs.data_layout = "NHWC"
        grouped_attrs.out_layout = "NHWC"
        # Not changed: kernel_size, strides, padding, output_padding, dilation,
        # kernel_layout, out_dtype

        # Concatenate the per-group weights along the kernel's "O" axis.
        weight_concat_axis = conv2d_transpose_attrs['kernel_layout'].index("O")
        weight = np.concatenate(weight_data, axis=weight_concat_axis)

        # Create SiMa IR attributes
        input_shape[current_layout.index('C')] *= groups
        output_shape[current_layout.index('C')] *= groups
        weight, attrs = _make_conv_attributes(tuple(input_shape), 'float32', tuple(output_shape),
                                              weight, grouped_attrs)
        weights_attrs = ConstantAttrs(data=weight)
        return ConvAddActivationAttrs(weights_attrs=weights_attrs, conv_attrs=attrs)
class ConstantMultiplyAdd():
    """
    Merge multiply with constant to add op. Do it for both inputs if possible.
    """

    @staticmethod
    def tvm_func_to_awesome_attributes(params: TVMConverterParams,
                                       func: tvm_def.TVMFunction) -> AwesomeAttributes:
        """
        Extract the decomposed constant, multiply and add Relay attributes from a
        "constant_multiply_add" composite function and translate them into SiMa IR
        ConstantMultiplyAddAttrs.

        :param params: Converter parameters, forwarded to the per-expression converter.
        :param func: Composite Relay function to translate.
        :return: ConstantMultiplyAddAttrs built from the add's input shapes and the
            constant(s) feeding the multiply(ies). The second constant is None when
            only one add input is a constant multiply.
        :raises ValueError: If the composite does not contain 3 or 5 expressions.
        """
        from afe.ir.attributes import ConstantAttrs, ConstantMultiplyAddAttrs, AddActivationAttrs
        from afe.tvm_converter._attributes import (
            get_non_global_var_awesome_attrs,
            extract_expressions_from_function_used_in_composite_function_attributes
        )

        # Extract all Relay expressions in the composite function
        composite_fn_expressions = extract_expressions_from_function_used_in_composite_function_attributes(func)
        # Convert each Relay expression to its corresponding AwesomeAttributes
        awesome_attrs = [get_non_global_var_awesome_attrs(expr, params, Backend.MLA)[0]
                         for expr in composite_fn_expressions]

        in2_const_attrs, in2_mul_attrs = None, None
        if len(awesome_attrs) == 3:
            # Only one input of add op is multiply with constant
            in1_const_attrs, in1_mul_attrs, add_attrs = awesome_attrs
        elif len(awesome_attrs) == 5:
            # Both inputs of add op are multiply with constant
            in1_const_attrs, in1_mul_attrs, in2_const_attrs, in2_mul_attrs, add_attrs = awesome_attrs
            assert isinstance(in2_const_attrs, ConstantAttrs)
            assert isinstance(in2_mul_attrs, AwesomeAttributes)
        else:
            raise ValueError(
                f"constant_multiply_add composite function must contain 3 or 5 expressions, "
                f"got {len(awesome_attrs)}")

        assert isinstance(add_attrs, AddActivationAttrs)
        assert isinstance(in1_const_attrs, ConstantAttrs)
        assert isinstance(in1_mul_attrs, AwesomeAttributes)

        return ConstantMultiplyAddAttrs(
            ScalarType.float32,
            add_attrs.add_attrs.lhs_input_shape,
            add_attrs.add_attrs.rhs_input_shape,
            in1_const_attrs,
            in2_const_attrs
        )
class RenamerWithAttrsChecker(tvm.relay.frontend.onnx.Renamer):
    """
    ONNX frontend converter that behaves like ``Renamer`` but first validates the
    node's ``precision`` attribute with a user-supplied predicate.

    :param new_name: Name of the operator the node is renamed to.
    :param checker: Predicate called with the precision string; must return True
        for the precision to be accepted.
    """

    def __init__(self, new_name, checker):
        super().__init__(new_name)
        self.checker = checker

    def __call__(self, inputs, attrs, *args):
        """
        Validate ``attrs["precision"]`` and then delegate to ``Renamer.__call__``.

        :raises TypeError: If ``precision`` is missing or rejected by ``checker``.
        """
        if "precision" not in attrs:
            raise TypeError("Parameter precision must be specified for annotation node.")
        precision = attrs["precision"]
        # String attributes normally arrive here as bytes (hence the decode);
        # an already-decoded str is accepted as well.
        if isinstance(precision, bytes):
            precision = precision.decode("utf-8")
        if not self.checker(precision):
            raise TypeError(f"Attribute {precision} is not a valid precision annotation.")
        return super().__call__(inputs, attrs, *args)
# Precision values accepted on "AnnotatePrecision" nodes (checked by
# RenamerWithAttrsChecker via the map below).
AVAILABLE_ANNOTATIONS = [
    "int8",
    "int16",
]
# Custom operator converters keyed by frontend name, then by operator name
# (presumably handed to the TVM frontend as its custom convert map — confirm at call site).
CUSTOM_CONVERT_MAP_DICT = dict(
    ONNX={
        # Grouped transposed convolution is decomposed into per-group convolutions.
        GroupConv2DTranspose.ONNX_op_name: GroupConv2DTranspose.map_impl(),
        # Precision hints are renamed only after validating their "precision" attribute.
        "AnnotatePrecision": RenamerWithAttrsChecker(
            tvm_def.TVM_PRECISION_HINT,
            lambda x: x in AVAILABLE_ANNOTATIONS,
        ),
        "AnnotateSensitivity": tvm.relay.frontend.onnx.Renamer(tvm_def.TVM_SENSITIVITY_HINT),
    },
)
# Maps a composite-function name to the function that translates that composite
# into SiMa IR AwesomeAttributes.
TVM_FUNC_TO_AWESOME_ATTRIBUTES_DICT = dict(
    group_conv2d_transpose=GroupConv2DTranspose.tvm_func_to_awesome_attributes,
    constant_multiply_add=ConstantMultiplyAdd.tvm_func_to_awesome_attributes,
)