#########################################################
# Copyright (C) 2021 SiMa Technologies, Inc.
#
# This material is SiMa proprietary and confidential.
#
# This material may not be copied or distributed without
# the express prior written permission of SiMa.
#
# All rights reserved.
#########################################################
# Code owner: Joey Chou, Jeffrey Spitz
#########################################################
"""
This file contains the MLAChecker class and the check
functions for all MLA supported IRs.
"""
# TODO(Joey): Add description for each check function.
from typing import Dict
from afe.backends import Backend, ExprCheckInfo, BaseChecker
import afe.backends.checker_utils as utils
import afe._tvm._defines as tvm_def
from afe.backends.backend_checker import Decision, Reject, Predicate, pany, pall, paccept
import sima_utils.logging.sima_logger as sima_logger
[docs]
# Switch read once, at import time, by upsampling_test: while it is True the
# utils.corners_not_aligned predicate is replaced by paccept there.
FORCE_SUPPORT_ALIGN_CORNER_TRUE = True
# General checks that apply to ANY operator assigned to the MLA
[docs]
# _create_mla_checks combines this predicate (via pall) with the
# operator-specific check of every operator that is not listed in
# MLA_EXEMPT_FROM_BASELINE_CHECKS.
baseline_test = pall([
    utils.all_tensors_are_float32,
    utils.input_is_4d_or_5d,
    utils.zero_axis_of_input_has_shape_of_1
])
# Conditions that are shared by avgpool and maxpool
[docs]
# Used directly as max_pool_test and as one member of avg_pool_test.
# NOTE(review): some helper names below read like rejection conditions
# ("..._are_greater_than_1", "..._incorrect_dimension"); presumably each
# predicate accepts when that condition does NOT hold — confirm against
# afe.backends.checker_utils.
pool_test = pall([
    utils.dilations_of_pooling_expression_are_greater_than_1,
    utils.pooling_expression_ceil_mode_is_false,
    utils.supported_pool_size,
    utils.pool_explicit_padding_incorrect_dimension
])
[docs]
# Max pooling has no conditions beyond the shared pooling ones, so this is an
# alias of pool_test (the same object, not a copy).
max_pool_test = pool_test
[docs]
# Average pooling: the shared pooling conditions plus one padding-related
# predicate. _create_mla_checks registers this for both TVM_AVG_POOL2D and
# TVM_AVG_POOL3D.
avg_pool_test = pall([
    utils.padding_is_present_and_included,
    pool_test
])
# Test for all supported binary elementwise operators.
# MLA supports more limited broadcasting than TVM does, and so
# the test checks for a supported form of broadcasting.
[docs]
# Elementwise operators are exempt from baseline_test (they are folded into
# MLA_EXEMPT_FROM_BASELINE_CHECKS), so the float32 and zero-axis predicates
# from the baseline are repeated here; the baseline rank predicate is not.
binary_elementwise_test = pall([
    utils.all_tensors_are_float32,
    utils.zero_axis_of_input_has_shape_of_1,
    # Either of the two input-shape predicates may be satisfied.
    pany([
        utils.binary_operator_have_same_input_shapes_or_one_scalar_input,
        utils.binary_operator_inputs_are_broadcastable
    ])
])
# In general the MLA accepts reduce ops where:
# * The input is 4D
# * The dimensions of the output shape match the input shape (keepdims=True)
# * We are reducing along nonzero axes, or we reduce along zero axis but the zero axis is singular
[docs]
# Registered for every operator in MLA_REDUCE_OPS (currently only TVM_MEAN).
# The rank predicate used here is the 4D-or-5D one; _create_mla_checks adds a
# further gate based on utils.input_is_4d for operators that are not listed
# in MLA_5D_OPS.
reduce_test = pall([
    utils.input_is_4d_or_5d,
    utils.keepdims_of_reduce_expression_is_true,
    utils.axes_are_nonzero_or_zero_axis_of_input_has_shape_of_1,
    utils.supported_mean_operator_axes_and_input_size
])
[docs]
# Registered for every operator name in MLA_CONV_TRANSPOSE_OPS.
conv_transpose_test = pall([
    utils.dilations_of_conv_transpose_expression_are_1,
    utils.depthwise_or_number_of_groups_is_1,
    utils.stride_is_power_of_2_upto_16
])
# Checker for group_conv2d_transpose.
# MLA supports group conv2d_transpose when groups number == channels,
# which represents a depthwise conv2d_transpose.
#
# If supported, AFE will translate the depthwise conv2d_transpose
# to upscale + depthwise conv2d to the MLA backend graph
[docs]
# Registered under the 'group_conv2d_transpose' key. That operator is listed
# in MLA_EXEMPT_FROM_BASELINE_CHECKS, so baseline_test is not added to it.
group_conv2d_transpose_test = pall([
    utils.number_of_groups_equals_number_of_input_channels,
    utils.stride_is_1_or_2,
])
# TODO SWMLA-3145, upsampling is deprecated, map to Resize instead
[docs]
# Registered for tvm_def.TVM_UPSAMPLING.
# The conditional expression is evaluated once, when the module is imported:
# with FORCE_SUPPORT_ALIGN_CORNER_TRUE set, paccept takes the place of
# utils.corners_not_aligned in the list.
upsampling_test = pall([
    utils.check_resize_is_supported,
    utils.corners_not_aligned if not FORCE_SUPPORT_ALIGN_CORNER_TRUE else paccept
])
[docs]
# Registered for tvm_def.TVM_IMAGE_RESIZE2D. Shares
# utils.check_resize_is_supported with upsampling_test.
image_resize_test = pall([
    utils.check_transformation_mode_is_other_than_tf_crop_resize,
    utils.pytorch_half_pixel_is_half_pixel,
    utils.check_resize_is_supported
])
# Don't support batch dimension concat
[docs]
# A single predicate used as-is (no pall wrapper); registered under the
# 'tuple_concat' key.
tuple_concat_test = utils.tuple_concat_axis_is_not_0
# The compiler supports strided slice if:
# * input is a 4D tensor
# * strides are all 1
# * begin of channel axis is a multiple of 16 and end of channel
# axis is a multiple of 16 (except for the last slice)
[docs]
# Registered for tvm_def.TVM_STRIDED_SLICE.
strided_slice_test = pall([
    utils.input_is_4d,
    utils.strided_slice_stride_is_1,
    # The helper name is spelled "chanel" where it is referenced here; keep
    # the spelling in sync with afe.backends.checker_utils.
    utils.strided_slice_on_chanel_axis_is_mul_of_16
])
# Should check that the input type of argmax is float32, but
# this isn't possible with the data currently in ExprCheckInfo.
[docs]
# tvm_def.TVM_ARGMAX is listed in MLA_EXEMPT_FROM_BASELINE_CHECKS, so the
# rank and zero-axis predicates are spelled out here instead of coming from
# baseline_test, and the dtype predicate is the int32 one rather than float32.
argmax_test = pall([
    utils.input_is_4d,
    utils.zero_axis_of_input_has_shape_of_1,
    utils.keepdims_of_reduce_expression_is_true,
    utils.reduce_axis_is_3,
    utils.all_tensors_are_int32
])
[docs]
# Registered for tvm_def.TVM_SOFTMAX.
softmax_test = pall([
    utils.input_is_4d,
    utils.softmax_axis_is_3
])
[docs]
# Registered for tvm_def.TVM_VARIANCE; wraps a single predicate in pall.
variance_test = pall([
    utils.supported_variance_axis
])
[docs]
# Registered under the 'grid_sample' key; wraps a single predicate in pall.
gridsample_test = pall([
    utils.supported_gridsample
])
[docs]
def tuple_einsum_test(num_inputs: int) -> Predicate:
    """
    Build the predicate used for einsum expressions.

    The result combines, via pall, an input-count predicate created by
    utils.check_number_of_einsum_inputs with
    utils.einsum_equation_is_supported.

    :param num_inputs: Value forwarded to utils.check_number_of_einsum_inputs.
    :return: The combined predicate.
    """
    input_count_check = utils.check_number_of_einsum_inputs(num_inputs)
    return pall([input_count_check, utils.einsum_equation_is_supported])
[docs]
# Registered for every operator name in MLA_CONV_OPS; wraps a single
# predicate in pall.
conv_test = pall([
    utils.check_strides_in_conv
])
[docs]
# Registered for tvm_def.TVM_BROADCAST_TO; wraps a single predicate in pall.
broadcast_to_test = pall([
    utils.broadcast_to_output_is_4d_and_broadcastable
])
[docs]
# Registered for tvm_def.TVM_TRANSPOSE; wraps a single predicate in pall.
transpose_test = pall([
    utils.check_transpose_affecting_batch_axis
])
# Quantization not implemented for tvm_def.TVM_PROD, tvm_def.TVM_SUM
# Compiler not implemented for tvm_def.TVM_MIN, tvm_def.TVM_MAX
[docs]
# Operators that _create_mla_checks registers with reduce_test.
MLA_REDUCE_OPS = [tvm_def.TVM_MEAN]
[docs]
# NOTE(review): this list is not referenced by _create_mla_checks in the
# visible code, so these operators get no entry in the check table and
# _check_mla would reject them with code -1 — confirm whether that is
# intentional or whether the list is consumed elsewhere.
MLA_ADAPTIVE_POOL_OPS = [tvm_def.TVM_ADAPTIVE_AVG_POOL2D, tvm_def.TVM_ADAPTIVE_MAX_POOL2D]
[docs]
# Operators that _create_mla_checks registers with binary_elementwise_test.
# The whole list is also folded into MLA_EXEMPT_FROM_BASELINE_CHECKS.
MLA_ELEMENTWISE_OPS = [tvm_def.TVM_ADD, 'add_relu', 'add_clip', tvm_def.TVM_MULTIPLY,
                       tvm_def.TVM_SUBTRACT, tvm_def.TVM_DIVIDE, tvm_def.TVM_BIAS_ADD]
# These operators have no requirements beyond the baseline checks
[docs]
# _create_mla_checks registers each of these with paccept as the
# operator-specific check; baseline_test is then added on top.
# NOTE(review): tvm_def.TVM_GLOBAL_AVG_POOL3D is listed here but not in
# MLA_5D_OPS, so it also receives the utils.input_is_4d-based gate added at
# the end of _create_mla_checks — confirm this is intended for a 3D pooling op.
MLA_BASELINE_CHECKS_ONLY_OPS = [
    tvm_def.TVM_GLOBAL_AVG_POOL2D,
    tvm_def.TVM_GLOBAL_AVG_POOL3D,
    tvm_def.TVM_RELU,
    tvm_def.TVM_LEAKY_RELU,
    tvm_def.TVM_SIGMOID,
    tvm_def.TVM_TANH,
    tvm_def.TVM_EXP,
    tvm_def.TVM_LOG,
    tvm_def.TVM_SQRT,
    tvm_def.TVM_LRN,
    tvm_def.TVM_RSQRT,
    tvm_def.TVM_LOG2,
    tvm_def.TVM_LOG10,
    tvm_def.TVM_CLIP,
    tvm_def.TVM_DEPTH_TO_SPACE,
    'swish',
    'hard_swish',
    'hard_sigmoid',
    'constant_multiply_add',
    'elu',
    'softplus',
    'erf',
    'layer_norm',
    'slice_concat',
    'rms_norm',
    'gelu',
    'instance_norm'
]
[docs]
# Operator names that _create_mla_checks registers with conv_test.
# The three 3D names also appear in MLA_5D_OPS.
MLA_CONV_OPS = [
    # 2D
    'conv2d',
    'conv2d_clip',
    'conv2d_relu',
    'conv2d_add',
    'conv2d_add_clip',
    'conv2d_add_relu',
    'conv2d_add_mul',
    # 3D
    'conv3d',
    'conv3d_add',
    'conv3d_add_relu',
]
[docs]
# Operator names that _create_mla_checks registers with conv_transpose_test.
# The two 3D names also appear in MLA_5D_OPS.
MLA_CONV_TRANSPOSE_OPS = [
    # 2D
    'conv2dtranspose',
    'conv2dtranspose_clip',
    'conv2dtranspose_relu',
    'conv2dtranspose_add',
    'conv2dtranspose_add_clip',
    'conv2dtranspose_add_relu',
    # 3D
    'conv3dtranspose',
    'conv3dtranspose_add',
]
[docs]
# Operators for which _create_mla_checks does NOT add baseline_test: the two
# explicit entries below plus every operator in MLA_ELEMENTWISE_OPS.
MLA_EXEMPT_FROM_BASELINE_CHECKS = {
    # TODO SWMLA-3144
    'group_conv2d_transpose',
    tvm_def.TVM_ARGMAX,
    *MLA_ELEMENTWISE_OPS,
}
# These operators can support 5D input
[docs]
# Operators listed here skip the utils.input_is_4d-based gate that
# _create_mla_checks adds to every other operator.
# NOTE(review): 'variance' is written as a string literal here while the
# check table is keyed by tvm_def.TVM_VARIANCE; this assumes the two are
# equal — TODO confirm against afe._tvm._defines.
MLA_5D_OPS = [
    'conv3d',
    'conv3d_add',
    'conv3d_add_relu',
    tvm_def.TVM_AVG_POOL3D,
    'variance',
    'instance_norm',
    'conv3dtranspose',
    'conv3dtranspose_add',
    'tuple_concat',
]
def _create_mla_checks() -> Dict[str, Predicate]:
    """
    Build the table that maps an operator name to its complete MLA predicate.

    The table is assembled in three stages:

    1. Every supported operator is registered with its operator-specific
       predicate.
    2. For each operator that is not in MLA_EXEMPT_FROM_BASELINE_CHECKS the
       predicate becomes pall([baseline_test, <predicate>]).
    3. For each operator that is not in MLA_5D_OPS the predicate becomes
       pall([<4D gate>, <predicate>]), where the gate is a pany over
       utils.input_is_4d and
       utils.binary_operator_have_same_input_shapes_or_one_scalar_input.

    :return: Dictionary from operator name to predicate.
    """
    checks: Dict[str, Predicate] = {}

    # Stage 1a: operator families that share one predicate.
    op_families = (
        (MLA_BASELINE_CHECKS_ONLY_OPS, paccept),
        (MLA_ELEMENTWISE_OPS, binary_elementwise_test),
        (MLA_REDUCE_OPS, reduce_test),
        (MLA_CONV_OPS, conv_test),
        (MLA_CONV_TRANSPOSE_OPS, conv_transpose_test),
    )
    for family, predicate in op_families:
        checks.update(dict.fromkeys(family, predicate))

    # Stage 1b: operators with an individual predicate.
    checks.update({
        tvm_def.TVM_AVG_POOL2D: avg_pool_test,
        tvm_def.TVM_AVG_POOL3D: avg_pool_test,
        tvm_def.TVM_MAX_POOL2D: max_pool_test,
        'group_conv2d_transpose': group_conv2d_transpose_test,
        # PRelu must use channel axis 3. TODO verify, SWMLA-3144
        'prelu_const': utils.prelu_axis_is_3,
        # Only support 'constant' pad_mode. Don't support batch dimension padding.
        tvm_def.TVM_PAD: utils.pad_mode_is_constant_and_pad_width_is_0,
        tvm_def.TVM_UPSAMPLING: upsampling_test,
        tvm_def.TVM_IMAGE_RESIZE2D: image_resize_test,
        'tuple_concat': tuple_concat_test,
        tvm_def.TVM_STRIDED_SLICE: strided_slice_test,
        tvm_def.TVM_ARGMAX: argmax_test,
        tvm_def.TVM_SOFTMAX: softmax_test,
        'tuple_einsum': tuple_einsum_test(num_inputs=2),
        'single_input_tuple_einsum': tuple_einsum_test(num_inputs=1),
        tvm_def.TVM_BROADCAST_TO: broadcast_to_test,
        tvm_def.TVM_TRANSPOSE: transpose_test,
        tvm_def.TVM_VARIANCE: variance_test,
        'grid_sample': gridsample_test,
    })

    # Stage 2: add the baseline checks to all supported operators that are
    # not in the exempt list.
    checks = {
        name: (
            predicate if name in MLA_EXEMPT_FROM_BASELINE_CHECKS
            else pall([baseline_test, predicate])
        )
        for name, predicate in checks.items()
    }

    # Stage 3: restrict non-5D operators to 4D tensors. A fresh gate is built
    # per operator, as in a plain loop.
    return {
        name: (
            predicate if name in MLA_5D_OPS
            else pall([
                pany([
                    utils.input_is_4d,
                    utils.binary_operator_have_same_input_shapes_or_one_scalar_input
                ]),
                predicate
            ])
        )
        for name, predicate in checks.items()
    }
# Operator-name -> predicate table, built once when the module is imported;
# _check_mla looks operators up here by args.name.
_mla_checks = _create_mla_checks()
def _check_mla(args: ExprCheckInfo) -> Decision:
    """
    Decide whether the expression described by ``args`` can go to the MLA.

    The predicate is looked up in the module-level check table by
    ``args.name``. An operator with no entry is rejected with error code -1.
    Every rejection is reported through sima_logger.sima_log_info.

    :param args: Information about the expression being checked.
    :return: The decision produced by the predicate, or Reject([-1]) for an
        operator that has no entry in the table.
    """
    predicate = _mla_checks.get(args.name)
    if predicate is None:
        # Unsupported operator.
        outcome = Reject([-1])
    else:
        outcome = predicate(args)

    if not isinstance(outcome, Reject):
        return outcome

    sima_logger.sima_log_info(
        f"Cannot assign node {args.name}_{args.idx} to MLA.\n\t"
        f"{utils.unsupported_op_code_to_message(outcome.error_codes)}"
    )
    return outcome
[docs]
class MLAChecker(BaseChecker):
    """
    Checker class for MLA backend. See BaseChecker for detailed documentation.

    This subclass defines no methods of its own; it only sets three class
    attributes (a name, the backend, and the predicate), presumably consumed
    by BaseChecker.
    """
    # Display name of this checker.
    _checkers_name: str = "MLA Checker"
    # Backend this checker decides for.
    _backend: Backend = Backend.MLA
    # Module-level dispatch function: looks up the operator's predicate and
    # rejects operators without an entry using error code -1.
    # NOTE(review): a plain function stored as a class attribute is bound as
    # a method when accessed through an instance — confirm how BaseChecker
    # invokes _predicate.
    _predicate = _check_mla