#########################################################
# Copyright (C) 2022 SiMa Technologies, Inc.
#
# This material is SiMa proprietary and confidential.
#
# This material may not be copied or distributed without
# the express prior written permission of SiMa.
#
# All rights reserved.
#########################################################
# Code owner: Ljubomir Papuga
#########################################################
"""
This file contains the QuantizedMLAChecker class and the
check functions for all MLA supported IRs for pre-quantized
models.
"""
from typing import Dict
import afe.backends.checker_utils as utils
import afe._tvm._defines as tvm_def
from afe.backends import Backend, ExprCheckInfo, BaseChecker
from afe.backends.backend_checker import Decision, Accept, Reject, Predicate, pany, pall, paccept
from afe.backends.mla.mla_checkers import binary_elementwise_test, reduce_test, avg_pool_test, conv_transpose_test, \
max_pool_test, tuple_concat_test, image_resize_test
# General checks which apply to ANY operator assigned to the MLA:
# all tensors int8, input rank 4 or 5, and batch dimension of size 1.
baseline_int_test = pall([
    utils.all_tensors_are_int8,
    utils.input_is_4d_or_5d,
    utils.zero_axis_of_input_has_shape_of_1
])
def _non_qnn_test(args: ExprCheckInfo) -> Decision:
    """Accept an expression iff every one of its output types is int8 or int32."""
    # The last entry in the CheckerAttr list should contain output types.
    output_attrs = args.attrs[-1]
    if any(ot not in ('int8', 'int32') for ot in output_attrs.output_types):
        return Reject([])
    return Accept()
# Checks for quantized binary elementwise operators.
# Do not accept broadcastable inputs due to a bug, SWMLA-3810.
binary_quantized_elementwise_test = pall([
    utils.all_tensors_are_int8,
    utils.zero_axis_of_input_has_shape_of_1,
    utils.binary_operator_have_same_input_shapes_or_one_scalar_input
])
# QNN binary elementwise operators supported by the MLA.
MLA_BINARY_ELEMENTWISE_OPS = {
    'qnn_add',
    'qnn_subtract',
    'qnn_mul'
}

# Operators that skip the baseline int checks; the binary elementwise
# operators apply their own type/shape predicate instead.
MLA_EXEMPT_FROM_BASELINE_CHECKS = set(MLA_BINARY_ELEMENTWISE_OPS)
def _make_quantized_mla_checkers() -> Dict[str, Predicate]:
    """Build the mapping from quantized operator name to its MLA check predicate."""
    # Seed with the binary elementwise operators, which share one predicate.
    checks: Dict[str, Predicate] = {
        op: binary_quantized_elementwise_test for op in MLA_BINARY_ELEMENTWISE_OPS
    }
    checks['qnn_conv2d_bias'] = paccept
    checks['qnn_avg_pool2d'] = avg_pool_test
    checks['qnn_leaky_relu'] = paccept
    checks[tvm_def.TVM_MAX_POOL2D] = max_pool_test  # Same TVM operator for int and float
    checks['qnn_conv2d_transpose'] = conv_transpose_test
    checks['qnn_conv2d_transpose_add'] = conv_transpose_test
    checks['qnn_concatenate'] = tuple_concat_test
    checks['image.resize2d'] = image_resize_test  # Same TVM operator for int and float
    # Not included due to missing backend support:
    # qnn_sum, qnn_avg_pool3d, qnn_requantize, qnn_dense
    # Prepend the baseline checks to every supported operator that is not
    # in the exempt list.
    return {
        op: check if op in MLA_EXEMPT_FROM_BASELINE_CHECKS
        else pall([baseline_int_test, check])
        for op, check in checks.items()
    }
# Registry of per-operator MLA predicates, built once at import time.
_quantized_mla_checks = _make_quantized_mla_checkers()
def test_mla(args: ExprCheckInfo) -> Decision:
    """
    Check whether the expression described by args is supported by the MLA.

    :param args: Information about the expression being checked.
    :return: The registered predicate's Decision for the operator named
        args.name, or Reject if no predicate is registered for it.
    """
    test = _quantized_mla_checks.get(args.name)
    return test(args) if test is not None else Reject([])
class QuantizedMLAChecker(BaseChecker):
    """
    The factory class for registering MLA IR checker for pre-quantized models.
    """
    # Human-readable name of this checker.
    _checkers_name: str = "Quantized MLA Checker"
    # Backend that accepted expressions are assigned to.
    _backend: Backend = Backend.MLA
    # Predicate that decides per-expression MLA support.
    _predicate: Predicate = test_mla