Source code for afe.backends.mla.quantized_mla_checkers

#########################################################
# Copyright (C) 2022 SiMa Technologies, Inc.
#
# This material is SiMa proprietary and confidential.
#
# This material may not be copied or distributed without
# the express prior written permission of SiMa.
#
# All rights reserved.
#########################################################
# Code owner: Ljubomir Papuga
#########################################################
"""
This file contains the QuantizedMLAChecker class and the
check functions for all MLA-supported IRs for pre-quantized
models.
"""

from typing import Dict

import afe.backends.checker_utils as utils
import afe._tvm._defines as tvm_def
from afe.backends import Backend, ExprCheckInfo, BaseChecker
from afe.backends.backend_checker import Decision, Accept, Reject, Predicate, pany, pall, paccept
from afe.backends.mla.mla_checkers import binary_elementwise_test, reduce_test, avg_pool_test, conv_transpose_test, \
    max_pool_test, tuple_concat_test, image_resize_test

# General check which applies to ANY operator assigned to the MLA
baseline_int_test = pall([
    utils.all_tensors_are_int8,
    utils.input_is_4d_or_5d,
    utils.zero_axis_of_input_has_shape_of_1
])
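# Illustrative note (not in the original source): pall() combines predicates
# conjunctively, so the baseline accepts only when every listed check accepts:
#
#     decision = baseline_int_test(args)   # Accept() only if all three checks pass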
def _non_qnn_test(args: ExprCheckInfo) -> Decision:
    # The last entry in CheckerAttr list should contain output types
    attrs = args.attrs[-1]
    if all(ot in ('int8', 'int32') for ot in attrs.output_types):
        return Accept()
    else:
        return Reject([])
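# Illustrative sketch (hypothetical values, not in the original source):
# _non_qnn_test inspects only the output types recorded in the final
# CheckerAttr entry, so
#
#     attrs.output_types == ('int8', 'int32')   ->  Accept()
#     attrs.output_types == ('float32',)        ->  Reject([])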
# Do not accept broadcastable inputs due to a bug, SWMLA-3810
binary_quantized_elementwise_test = pall([
    utils.all_tensors_are_int8,
    utils.zero_axis_of_input_has_shape_of_1,
    utils.binary_operator_have_same_input_shapes_or_one_scalar_input
])
MLA_BINARY_ELEMENTWISE_OPS = {'qnn_add', 'qnn_subtract', 'qnn_mul'}
MLA_EXEMPT_FROM_BASELINE_CHECKS = set(MLA_BINARY_ELEMENTWISE_OPS)
def _make_quantized_mla_checkers() -> Dict[str, Predicate]:
    mla_checks: Dict[str, Predicate] = dict()
    for op in MLA_BINARY_ELEMENTWISE_OPS:
        mla_checks[op] = binary_quantized_elementwise_test
    mla_checks['qnn_conv2d_bias'] = paccept
    mla_checks['qnn_avg_pool2d'] = avg_pool_test
    mla_checks['qnn_leaky_relu'] = paccept
    mla_checks[tvm_def.TVM_MAX_POOL2D] = max_pool_test  # Same TVM operator for int and float
    mla_checks['qnn_conv2d_transpose'] = conv_transpose_test
    mla_checks['qnn_conv2d_transpose_add'] = conv_transpose_test
    mla_checks['qnn_concatenate'] = tuple_concat_test
    mla_checks['image.resize2d'] = image_resize_test  # Same TVM operator for int and float

    # Not included due to missing backend support:
    # qnn_sum, qnn_avg_pool3d, qnn_requantize, qnn_dense

    # Add the baseline checks to all operators that are supported and are not
    # in the exempt list
    for op, check in mla_checks.items():
        if op not in MLA_EXEMPT_FROM_BASELINE_CHECKS:
            mla_checks[op] = pall([baseline_int_test, check])
    return mla_checks


_quantized_mla_checks = _make_quantized_mla_checkers()
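# Illustrative note (schematic, not in the original source): after
# construction, every non-exempt operator runs the baseline int8 checks
# before its own check, e.g.
#
#     _quantized_mla_checks['qnn_avg_pool2d']   # baseline + avg_pool_test
#     _quantized_mla_checks['qnn_add']          # exempt: elementwise test only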
def test_mla(args: ExprCheckInfo) -> Decision:
    test = _quantized_mla_checks.get(args.name)
    return test(args) if test is not None else Reject([])
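# Illustrative sketch (hypothetical `args`, not in the original source):
# dispatch is by operator name, so an operator without a registered check is
# rejected outright:
#
#     args.name = 'qnn_dense'               # no backend support, unregistered
#     assert isinstance(test_mla(args), Reject)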
class QuantizedMLAChecker(BaseChecker):
    """
    The factory class for registering the MLA IR checker for pre-quantized
    models.
    """
    _checkers_name: str = "Quantized MLA Checker"
    _backend: Backend = Backend.MLA
    _predicate: Predicate = test_mla
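# Illustrative sketch (an assumption; BaseChecker's public interface is not
# shown in this module): the registered predicate decides, per expression,
# whether the MLA backend supports it:
#
#     decision = QuantizedMLAChecker._predicate(args)
#     # -> Accept() if args.name has a registered, passing check;
#     #    Reject([...]) otherwise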