
#########################################################
# Copyright (C) 2021 SiMa Technologies, Inc.
#
# This material is SiMa proprietary and confidential.
#
# This material may not be copied or distributed without
# the express prior written permission of SiMa.
#
# All rights reserved.
#########################################################
# Code owner: Joey Chou, Jeffrey Spitz
#########################################################
"""
This file contains the MLAChecker class and the check
functions for all MLA-supported IRs.
"""
# TODO(Joey): Add description for each check function.
from typing import Dict

from afe.backends import Backend, ExprCheckInfo, BaseChecker
import afe.backends.checker_utils as utils
import afe._tvm._defines as tvm_def
from afe.backends.backend_checker import Decision, Reject, Predicate, pany, pall, paccept

import sima_utils.logging.sima_logger as sima_logger

FORCE_SUPPORT_ALIGN_CORNER_TRUE = True
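# When True, the corners_not_aligned restriction in upsampling_test below is
# bypassed (the predicate is replaced with paccept).
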
# General checks that apply to ANY operator assigned to the MLA
baseline_test = pall([
    utils.all_tensors_are_float32,
    utils.input_is_4d_or_5d,
    utils.zero_axis_of_input_has_shape_of_1
])
# Conditions that are shared by avgpool and maxpool
pool_test = pall([
    utils.dilations_of_pooling_expression_are_greater_than_1,
    utils.pooling_expression_ceil_mode_is_false,
    utils.supported_pool_size,
    utils.pool_explicit_padding_incorrect_dimension
])
max_pool_test = pool_test
avg_pool_test = pall([
    utils.padding_is_present_and_included,
    pool_test
])
# Test for all supported binary elementwise operators.
# MLA supports more limited broadcasting than TVM does, and so
# the test checks for a supported form of broadcasting.
binary_elementwise_test = pall([
    utils.all_tensors_are_float32,
    utils.zero_axis_of_input_has_shape_of_1,
    pany([
        utils.binary_operator_have_same_input_shapes_or_one_scalar_input,
        utils.binary_operator_inputs_are_broadcastable
    ])
])
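# Illustrative shape examples (hypothetical, assuming NHWC layout):
#   (1, 32, 32, 16) + (1, 32, 32, 16) -- identical shapes, accepted
#   (1, 32, 32, 16) + scalar          -- one scalar operand, accepted
#   (1, 32, 32, 16) + (1, 1, 1, 16)   -- accepted only if it passes the
#                                        restricted broadcastability check above
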
# In general the MLA accepts reduce ops where:
# * The input is 4D
# * The dimensions of the output shape match the input shape (keepdims=True)
# * We are reducing along nonzero axes, or we reduce along the zero axis but
#   the zero axis is singular
reduce_test = pall([
    utils.input_is_4d_or_5d,
    utils.keepdims_of_reduce_expression_is_true,
    utils.axes_are_nonzero_or_zero_axis_of_input_has_shape_of_1,
    utils.supported_mean_operator_axes_and_input_size
])
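# For example (hypothetical shapes): a mean over axes (1, 2) of a
# (1, 7, 7, 512) input with keepdims=True yields (1, 1, 1, 512) and satisfies
# the keepdims and nonzero-axis checks above; keepdims=False would yield
# (1, 512) and be rejected.
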
conv_transpose_test = pall([
    utils.dilations_of_conv_transpose_expression_are_1,
    utils.depthwise_or_number_of_groups_is_1,
    utils.stride_is_power_of_2_upto_16
])
# Checker for group_conv2d_transpose.
# MLA supports group conv2d_transpose when the number of groups == channels,
# which represents a depthwise conv2d_transpose.
#
# If supported, AFE will translate the depthwise conv2d_transpose
# to upscale + depthwise conv2d in the MLA backend graph.
group_conv2d_transpose_test = pall([
    utils.number_of_groups_equals_number_of_input_channels,
    utils.stride_is_1_or_2,
])
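# For example (hypothetical): a conv2d_transpose with groups=64 over a
# 64-channel input is depthwise; with stride 2 it is rewritten as a 2x
# upscale followed by a depthwise conv2d, per the comment above.
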
# TODO SWMLA-3145, upsampling is deprecated, map to Resize instead
upsampling_test = pall([
    utils.check_resize_is_supported,
    utils.corners_not_aligned if not FORCE_SUPPORT_ALIGN_CORNER_TRUE else paccept
])
image_resize_test = pall([
    utils.check_transformation_mode_is_other_than_tf_crop_resize,
    utils.pytorch_half_pixel_is_half_pixel,
    utils.check_resize_is_supported
])
# Concat along the batch dimension is not supported
tuple_concat_test = utils.tuple_concat_axis_is_not_0
# The compiler supports strided slice if:
# * the input is a 4D tensor
# * strides are all 1
# * the begin of the channel axis is a multiple of 16 and the end of the
#   channel axis is a multiple of 16 (except for the last slice)
strided_slice_test = pall([
    utils.input_is_4d,
    utils.strided_slice_stride_is_1,
    utils.strided_slice_on_chanel_axis_is_mul_of_16
])
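# For example (hypothetical): on an input with 40 channels, slices [0:16] and
# [16:32] are accepted, and the final slice [32:40] is accepted even though
# 40 is not a multiple of 16; a slice such as [0:20] elsewhere would be
# rejected.
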
# Should check that the input type of argmax is float32, but
# this isn't possible with the data currently in ExprCheckInfo.
argmax_test = pall([
    utils.input_is_4d,
    utils.zero_axis_of_input_has_shape_of_1,
    utils.keepdims_of_reduce_expression_is_true,
    utils.reduce_axis_is_3,
    utils.all_tensors_are_int32
])
softmax_test = pall([utils.input_is_4d, utils.softmax_axis_is_3])
variance_test = pall([utils.supported_variance_axis])
gridsample_test = pall([utils.supported_gridsample])
def tuple_einsum_test(num_inputs: int) -> Predicate:
    return pall([
        utils.check_number_of_einsum_inputs(num_inputs),
        utils.einsum_equation_is_supported
    ])
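# Used below in _create_mla_checks: tuple_einsum_test(num_inputs=2) builds the
# predicate for 'tuple_einsum', and tuple_einsum_test(num_inputs=1) builds the
# predicate for 'single_input_tuple_einsum'.
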
conv_test = pall([utils.check_strides_in_conv])
broadcast_to_test = pall([utils.broadcast_to_output_is_4d_and_broadcastable])
transpose_test = pall([utils.check_transpose_affecting_batch_axis])
# Quantization not implemented for tvm_def.TVM_PROD, tvm_def.TVM_SUM.
# Compiler not implemented for tvm_def.TVM_MIN, tvm_def.TVM_MAX.
MLA_REDUCE_OPS = [tvm_def.TVM_MEAN]
MLA_ADAPTIVE_POOL_OPS = [tvm_def.TVM_ADAPTIVE_AVG_POOL2D, tvm_def.TVM_ADAPTIVE_MAX_POOL2D]
MLA_ELEMENTWISE_OPS = [
    tvm_def.TVM_ADD, 'add_relu', 'add_clip', tvm_def.TVM_MULTIPLY,
    tvm_def.TVM_SUBTRACT, tvm_def.TVM_DIVIDE, tvm_def.TVM_BIAS_ADD
]
# These operators have no requirements beyond the baseline checks
MLA_BASELINE_CHECKS_ONLY_OPS = [
    tvm_def.TVM_GLOBAL_AVG_POOL2D, tvm_def.TVM_GLOBAL_AVG_POOL3D,
    tvm_def.TVM_RELU, tvm_def.TVM_LEAKY_RELU, tvm_def.TVM_SIGMOID,
    tvm_def.TVM_TANH, tvm_def.TVM_EXP, tvm_def.TVM_LOG, tvm_def.TVM_SQRT,
    tvm_def.TVM_LRN, tvm_def.TVM_RSQRT, tvm_def.TVM_LOG2, tvm_def.TVM_LOG10,
    tvm_def.TVM_CLIP, tvm_def.TVM_DEPTH_TO_SPACE,
    'swish', 'hard_swish', 'hard_sigmoid', 'constant_multiply_add', 'elu',
    'softplus', 'erf', 'layer_norm', 'slice_concat', 'rms_norm', 'gelu',
    'instance_norm'
]
MLA_CONV_OPS = [
    # 2D
    'conv2d', 'conv2d_clip', 'conv2d_relu', 'conv2d_add', 'conv2d_add_clip',
    'conv2d_add_relu', 'conv2d_add_mul',
    # 3D
    'conv3d', 'conv3d_add', 'conv3d_add_relu',
]
MLA_CONV_TRANSPOSE_OPS = [
    # 2D
    'conv2dtranspose', 'conv2dtranspose_clip', 'conv2dtranspose_relu',
    'conv2dtranspose_add', 'conv2dtranspose_add_clip',
    'conv2dtranspose_add_relu',
    # 3D
    'conv3dtranspose', 'conv3dtranspose_add',
]
MLA_EXEMPT_FROM_BASELINE_CHECKS = {
    # TODO SWMLA-3144
    'group_conv2d_transpose',
    tvm_def.TVM_ARGMAX,
}.union(set(MLA_ELEMENTWISE_OPS))
# These operators can support 5D input
MLA_5D_OPS = [
    'conv3d', 'conv3d_add', 'conv3d_add_relu', tvm_def.TVM_AVG_POOL3D,
    'variance', 'instance_norm', 'conv3dtranspose', 'conv3dtranspose_add',
    'tuple_concat',
]
def _create_mla_checks() -> Dict[str, Predicate]:
    mla_checks: Dict[str, Predicate] = {}
    for op in MLA_BASELINE_CHECKS_ONLY_OPS:
        mla_checks[op] = paccept
    for op in MLA_ELEMENTWISE_OPS:
        mla_checks[op] = binary_elementwise_test
    for op in MLA_REDUCE_OPS:
        mla_checks[op] = reduce_test
    for op in MLA_CONV_OPS:
        mla_checks[op] = conv_test
    for op in MLA_CONV_TRANSPOSE_OPS:
        mla_checks[op] = conv_transpose_test
    mla_checks[tvm_def.TVM_AVG_POOL2D] = avg_pool_test
    mla_checks[tvm_def.TVM_AVG_POOL3D] = avg_pool_test
    mla_checks[tvm_def.TVM_MAX_POOL2D] = max_pool_test
    mla_checks['group_conv2d_transpose'] = group_conv2d_transpose_test

    # PRelu must use channel axis 3. TODO verify, SWMLA-3144
    mla_checks['prelu_const'] = utils.prelu_axis_is_3

    # Only support 'constant' pad_mode. Don't support batch dimension padding.
    mla_checks[tvm_def.TVM_PAD] = utils.pad_mode_is_constant_and_pad_width_is_0
    mla_checks[tvm_def.TVM_UPSAMPLING] = upsampling_test
    mla_checks[tvm_def.TVM_IMAGE_RESIZE2D] = image_resize_test
    mla_checks['tuple_concat'] = tuple_concat_test
    mla_checks[tvm_def.TVM_STRIDED_SLICE] = strided_slice_test
    mla_checks[tvm_def.TVM_ARGMAX] = argmax_test
    mla_checks[tvm_def.TVM_SOFTMAX] = softmax_test
    mla_checks['tuple_einsum'] = tuple_einsum_test(num_inputs=2)
    mla_checks['single_input_tuple_einsum'] = tuple_einsum_test(num_inputs=1)
    mla_checks[tvm_def.TVM_BROADCAST_TO] = broadcast_to_test
    mla_checks[tvm_def.TVM_TRANSPOSE] = transpose_test
    mla_checks[tvm_def.TVM_VARIANCE] = variance_test
    mla_checks['grid_sample'] = gridsample_test

    # Add the baseline checks to all operators that are supported and are not
    # in the exempt list
    for op, check in mla_checks.items():
        if op not in MLA_EXEMPT_FROM_BASELINE_CHECKS:
            mla_checks[op] = pall([baseline_test, check])

    # Restrict non-5D operators to 4D tensors
    for op, check in mla_checks.items():
        if op not in MLA_5D_OPS:
            mla_checks[op] = pall([
                pany([
                    utils.input_is_4d,
                    utils.binary_operator_have_same_input_shapes_or_one_scalar_input
                ]),
                check
            ])
    return mla_checks


_mla_checks = _create_mla_checks()


def _check_mla(args: ExprCheckInfo) -> Decision:
    # Dispatch the check according to the operator.
    # Reject with code -1 if the operator is not supported.
    test = _mla_checks.get(args.name)
    decision = test(args) if test is not None else Reject([-1])
    if isinstance(decision, Reject):
        sima_logger.sima_log_info(f"Cannot assign node {args.name}_{args.idx} to MLA.\n\t"
                                  f"{utils.unsupported_op_code_to_message(decision.error_codes)}")
    return decision
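# Dispatch sketch (hypothetical example): for an ExprCheckInfo whose name is
# tvm_def.TVM_SOFTMAX, _check_mla looks up and evaluates the composed
# predicate pall([pany([utils.input_is_4d, ...]),
# pall([baseline_test, softmax_test])]); for a name that is not a key of
# _mla_checks it returns Reject([-1]) and logs the rejection.
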
class MLAChecker(BaseChecker):
    """
    Checker class for MLA backend.
    See BaseChecker for detailed documentation.
    """
    _checkers_name: str = "MLA Checker"
    _backend: Backend = Backend.MLA
    _predicate = _check_mla