#########################################################
# Copyright (C) 2021 SiMa Technologies, Inc.
#
# This material is SiMa proprietary and confidential.
#
# This material may not be copied or distributed without
# the express prior written permission of SiMa.
#
# All rights reserved.
#########################################################
# Code owner: Joey Chou
#########################################################
"""
TODO(Joey.Chou):
Use a registration mechanism to register:
* custom_convert_map
* MergeComposite pattern _PATTERN_TABLE
* tvm_func_to_awesome_attributes
"""
import copy
import numpy as np
import tvm
from tvm import relay
from tvm.relay.frontend.common import infer_type
from afe.backends import Backend
from afe.ir.attributes import AwesomeAttributes
from afe.ir.tensor_type import ScalarType
import afe.ir.utils as utils
import afe._tvm._defines as tvm_def
from afe._tvm._utils import get_expr_input_shapes, get_expr_output_shapes
from afe.tvm_converter.parameters import TVMConverterParams


class GroupConv2DTranspose:
    """
    Models: CenterNet
    Author: Alicja Kwasniewska, Joey Chou

    Decomposes an nn.conv2d_transpose with groups > 1 into a split, multiple
    single-group nn.conv2d_transpose calls, and a concatenate, as shown below:

                       |
                     split
          ---------------------------  ....
          |      |      |      |       ...
    tuple_get_item                    .....
          |
    nn.conv2d_transpose
          |
          |      |      |      |       ...
          ---------------------------  ....
                       |
                     tuple
                       |
                  concatenate
                       |
    """
    ONNX_op_name = "ConvTranspose"
    @staticmethod
    def map_impl():
        def _conv_single_group_rec(data_splitted, weight_splitted, attr, groups, out=None):
            if groups > 0:
                groups = groups - 1
                data_slice = relay.TupleGetItem(data_splitted, groups)
                weight_slice = relay.TupleGetItem(weight_splitted, groups)
                conv = relay.nn.conv2d_transpose(data_slice,
                                                 weight_slice,
                                                 strides=attr['strides'],
                                                 padding=attr['pads'],
                                                 dilation=attr['dilations'],
                                                 kernel_size=attr['kernel_shape'])
                if out:
                    out.insert(0, conv)
                else:
                    out = [conv]
                return _conv_single_group_rec(data_splitted, weight_splitted, attr, groups, out)
            else:
                return relay.concatenate(out, axis=1)
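
        # Recursion trace (illustrative, groups=3): the helper above builds the
        # list back to front, [conv_2] -> [conv_1, conv_2] -> [conv_0, conv_1, conv_2],
        # so relay.concatenate sees the per-group results in group order.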
        def _impl(inputs, attr, params):
            assert 2 <= len(inputs) <= 3, (f"ONNX's {GroupConv2DTranspose.ONNX_op_name} operator "
                                           f"must have either 2 inputs (data, weights) or 3 inputs "
                                           f"(data, weights, bias). Got {len(inputs)} inputs")
            # Default strides if they are not specified in attr
            data_shape = infer_type(inputs[0]).checked_type.shape
            weight_shape = infer_type(inputs[1]).checked_type.shape
            default_strides = (1,) * (len(data_shape) - 2)
            attr['strides'] = attr['strides'] if 'strides' in attr else default_strides
            # In ONNX, the default data layout is NCHW
            data_channel_axis = 1
            # In ONNX, the default weight layout is OIHW
            weight_output_channel_axis = 0
            groups = attr.get('group', 1)
            if groups == 1 or utils.is_depthwise_conv(data_shape[data_channel_axis],
                                                      weight_shape[weight_output_channel_axis], groups):
                # Use TVM's translation for transposed convolution
                from tvm.relay.frontend.onnx import ConvTranspose
                return ConvTranspose.get_converter(opset=17)(inputs, attr, params)
            # Otherwise, it is a transposed convolution with multiple groups.
            # There is no SiMa IR support for grouped transposed convolution,
            # so override TVM's frontend to decompose it into multiple transposed convolutions.
            assert groups > 1
            data_splitted = relay.split(inputs[0], groups, axis=data_channel_axis).astuple()
            weight_splitted = relay.split(inputs[1], groups, axis=weight_output_channel_axis).astuple()
            out = _conv_single_group_rec(data_splitted, weight_splitted, attr, groups)
            # If there are 3 inputs, the last input is the bias tensor
            if len(inputs) == 3:
                out = relay.nn.bias_add(out, inputs[2])
            return out
        return _impl
    @staticmethod
    def tvm_func_to_awesome_attributes(params: TVMConverterParams, func: tvm_def.TVMFunction) -> AwesomeAttributes:
        """
        Extract the Relay attributes of a decomposed grouped conv2d_transpose and
        translate them into SiMa IR Conv2DTransposeAttrs.
        """
        from afe.ir.attributes import ConstantAttrs, ConvAddActivationAttrs
        from afe.tvm_converter._attributes import (
            get_tvm_call_op_attr_in_dict, get_tvm_constant_attr_in_dict,
            extract_expressions_from_function_used_in_composite_function_attributes
        )
        from afe.tvm_converter._operators import _make_conv_attributes

        # Extract all Relay expressions in the composite function
        composite_fn_expressions = extract_expressions_from_function_used_in_composite_function_attributes(func)
        # Get the attributes of each relay.nn.conv2d_transpose call
        conv2d_transpose_expressions = [expr for expr in composite_fn_expressions
                                        if isinstance(expr, tvm_def.TVMCall)
                                        and expr.op.name == "nn.conv2d_transpose"]
        layout_transform_expressions = [expr for expr in composite_fn_expressions
                                        if isinstance(expr, tvm_def.TVMCall)
                                        and expr.op.name == "layout_transform"]
        weight_data = [get_tvm_constant_attr_in_dict(expr)['data']
                       for expr in composite_fn_expressions if isinstance(expr, tvm_def.TVMConstant)]
        conv2d_transpose_attrs_list = [get_tvm_call_op_attr_in_dict(expr)
                                       for expr in conv2d_transpose_expressions]

        # Get the input shape by transposing it to the correct layout; the channel
        # dimension is multiplied by the number of groups below.
        if len(layout_transform_expressions) > 0:
            current_layout = layout_transform_expressions[0].attrs.src_layout
        else:
            current_layout = "NHWC"
        dst_layout = "NHWC"
        input_shape = get_expr_input_shapes(conv2d_transpose_expressions[0])[0]
        input_shape = utils.transpose_attr_according_to_layout_strings(input_shape, dst_layout, current_layout)
        output_shape = get_expr_output_shapes(conv2d_transpose_expressions[0])[0]
        output_shape = utils.transpose_attr_according_to_layout_strings(output_shape, dst_layout, current_layout)

        # Create attributes for the grouped convolution corresponding to this composite operator
        groups = len(weight_data)
        conv2d_transpose_attrs = conv2d_transpose_attrs_list[0]
        grouped_attrs = copy.deepcopy(conv2d_transpose_expressions[0].attrs)
        grouped_attrs.groups = groups
        grouped_attrs.channels *= groups
        grouped_attrs.data_layout = "NHWC"
        grouped_attrs.out_layout = "NHWC"
        # Not changed: kernel_size, strides, padding, output_padding, dilation,
        # kernel_layout, out_dtype

        # Create the final weight by concatenating the per-group weights along the
        # output-channel axis
        weight_output_channel_axis = conv2d_transpose_attrs['kernel_layout'].index("O")
        weight = np.concatenate(weight_data, axis=weight_output_channel_axis)

        # Create SiMa IR attributes
        input_shape[current_layout.index('C')] *= groups
        output_shape[current_layout.index('C')] *= groups
        weight, attrs = _make_conv_attributes(tuple(input_shape), 'float32', tuple(output_shape), weight,
                                              grouped_attrs)
        weights_attrs = ConstantAttrs(data=weight)
        return ConvAddActivationAttrs(weights_attrs=weights_attrs, conv_attrs=attrs)
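
    # Weight reconstruction sketch (hypothetical shapes): with groups=2 and a
    # kernel_layout of "IOHW", per-group weights w0 and w1 of shape (4, 4, 3, 3)
    # are combined along the output-channel axis into a single (4, 8, 3, 3) tensor:
    #
    #   weight = np.concatenate([w0, w1], axis="IOHW".index("O"))  # axis=1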


class ConstantMultiplyAdd:
    """
    Merge a multiply-by-constant into an add op.
    Do it for both inputs when possible.
    """
    @staticmethod
    def tvm_func_to_awesome_attributes(params: TVMConverterParams, func: tvm_def.TVMFunction) -> AwesomeAttributes:
        """
        Extract the constant, multiply, and add Relay attributes from the composite
        constant_multiply_add function and translate them into SiMa IR
        ConstantMultiplyAddAttrs.
        """
        from afe.ir.attributes import ConstantAttrs, ConstantMultiplyAddAttrs, AddActivationAttrs
        from afe.tvm_converter._attributes import (
            get_non_global_var_awesome_attrs,
            extract_expressions_from_function_used_in_composite_function_attributes
        )

        # Extract all Relay expressions in the composite function
        composite_fn_expressions = extract_expressions_from_function_used_in_composite_function_attributes(func)
        # Convert each Relay expression to its corresponding AwesomeAttributes
        awesome_attrs = [get_non_global_var_awesome_attrs(expr, params, Backend.MLA)[0]
                         for expr in composite_fn_expressions]

        in2_const_attrs, in2_mul_attrs = None, None
        if len(awesome_attrs) == 3:
            # Only one input of the add op is multiplied by a constant
            in1_const_attrs, in1_mul_attrs, add_attrs = awesome_attrs
        elif len(awesome_attrs) == 5:
            # Both inputs of the add op are multiplied by a constant
            in1_const_attrs, in1_mul_attrs, in2_const_attrs, in2_mul_attrs, add_attrs = awesome_attrs
            assert isinstance(in2_const_attrs, ConstantAttrs)
            assert isinstance(in2_mul_attrs, AwesomeAttributes)
        else:
            raise ValueError(f"Expected 3 or 5 expressions in a constant_multiply_add "
                             f"composite function, got {len(awesome_attrs)}")
        assert isinstance(add_attrs, AddActivationAttrs)
        assert isinstance(in1_const_attrs, ConstantAttrs)
        assert isinstance(in1_mul_attrs, AwesomeAttributes)
        return ConstantMultiplyAddAttrs(
            ScalarType.float32, add_attrs.add_attrs.lhs_input_shape,
            add_attrs.add_attrs.rhs_input_shape, in1_const_attrs, in2_const_attrs
        )
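
    # Matched pattern sketch (assumption: this mirrors the MergeComposite
    # _PATTERN_TABLE entry mentioned in the module TODO; the table itself is
    # defined elsewhere). The composite covers either
    #
    #   add(multiply(x, const), y)                      # 3 extracted expressions
    #   add(multiply(x, const1), multiply(y, const2))   # 5 extracted expressions
    #
    # which is why tvm_func_to_awesome_attributes above accepts exactly 3 or 5
    # AwesomeAttributes.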


class RenamerWithAttrsChecker(tvm.relay.frontend.onnx.Renamer):
    def __init__(self, new_name, checker):
        super().__init__(new_name)
        self.checker = checker

    def __call__(self, inputs, attrs, *args):
        if "precision" not in attrs:
            raise TypeError("Parameter precision must be specified for an annotation node.")
        precision = attrs["precision"].decode("utf-8")
        if not self.checker(precision):
            raise TypeError(f"Attribute {precision} is not a valid precision annotation.")
        return super().__call__(inputs, attrs, *args)
AVAILABLE_ANNOTATIONS = ["int8", "int16"]

CUSTOM_CONVERT_MAP_DICT = {
    "ONNX": {
        GroupConv2DTranspose.ONNX_op_name: GroupConv2DTranspose.map_impl(),
        "AnnotatePrecision": RenamerWithAttrsChecker(tvm_def.TVM_PRECISION_HINT,
                                                     lambda x: x in AVAILABLE_ANNOTATIONS),
        "AnnotateSensitivity": tvm.relay.frontend.onnx.Renamer(tvm_def.TVM_SENSITIVITY_HINT)
    },
}
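
# Usage sketch (hedged: the import entry point that consumes this map lives
# outside this module, and `load_onnx_model` below is a hypothetical wrapper):
#
#   convert_map = CUSTOM_CONVERT_MAP_DICT["ONNX"]
#   mod, params = load_onnx_model(onnx_model, custom_convert_map=convert_map)
#
# Each entry overrides TVM's default converter for the matching ONNX op name,
# e.g. "ConvTranspose" is redirected to GroupConv2DTranspose.map_impl().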

TVM_FUNC_TO_AWESOME_ATTRIBUTES_DICT = {
    "group_conv2d_transpose": GroupConv2DTranspose.tvm_func_to_awesome_attributes,
    "constant_multiply_add": ConstantMultiplyAdd.tvm_func_to_awesome_attributes
}
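
# Dispatch sketch (assumption: the converter looks up a composite function by the
# "Composite" attribute that TVM's MergeComposite pass attaches; the actual
# dispatch site is not in this module):
#
#   composite_name = str(func.attrs["Composite"])  # e.g. "group_conv2d_transpose"
#   attrs = TVM_FUNC_TO_AWESOME_ATTRIBUTES_DICT[composite_name](params, func)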