# Source code for afe.core.mixed_precision.config

#########################################################
# Copyright (C) 2024 SiMa Technologies, Inc.
#
# This material is SiMa proprietary and confidential.
#
# This material may not be copied or distributed without
# the express prior written permission of SiMa.
#
# All rights reserved.
#########################################################
# Code owner: Shreyas Kera
#########################################################
import torch
import math
import operator

import model_compression_toolkit as mct

from typing import Any, Dict, List, Optional, Tuple


def get_core_config() -> mct.core.CoreConfig:
    """
    Get mct's core config, with details about error methods, bias correction, etc.
    TODO: Currently a fixed config, needs to be configurable based on user's choices from AFE specs.

    :return: core_config. CoreConfig with quantization specific details.
    """
    # Fixed quantization settings: MSE error search for activations, no-clipping
    # thresholds for weights, bias correction on, all optional graph rewrites off.
    quantization_config = mct.core.QuantizationConfig(
        activation_error_method=mct.core.QuantizationErrorMethod.MSE,
        weights_error_method=mct.core.QuantizationErrorMethod.NOCLIPPING,
        relu_bound_to_power_of_2=False,
        weights_bias_correction=True,
        weights_second_moment_correction=False,
        input_scaling=False,
        softmax_shift=False,
        shift_negative_activation_correction=False,
        activation_channel_equalization=False,
        z_threshold=math.inf,
        min_threshold=2 ** -16,
        l_p_value=2,
        linear_collapsing=False,
        residual_collapsing=False,
        shift_negative_ratio=0.05,
        shift_negative_threshold_recalculation=False,
        shift_negative_params_search=False,
    )
    return mct.core.CoreConfig(quantization_config)
def _get_op_quantization_configs(activation_n_bits: int) -> List[mct.target_platform.OpQuantizationConfig]:
    """
    Get list of quantization configuration parameters.

    :param activation_n_bits: Number of bits used in quantization.
    :return: Configuration options for operator quantization.
    """
    oqc = mct.target_platform.op_quantization_config
    # All weight attributes use symmetric quantization; hoist the enum lookup.
    symmetric = mct.target_platform.QuantizationMethod.SYMMETRIC

    # Fallback config for weight attributes: 8-bit, per-tensor, quantization disabled.
    default_weight_attr_config = oqc.AttributeQuantizationConfig(
        weights_quantization_method=symmetric,
        weights_n_bits=8,
        weights_per_channel_threshold=False,
        enable_weights_quantization=False,
        lut_values_bitwidth=None,
    )
    # Kernel weights: 8-bit, per-channel thresholds, quantization enabled.
    kernel_base_config = oqc.AttributeQuantizationConfig(
        weights_quantization_method=symmetric,
        weights_n_bits=8,
        weights_per_channel_threshold=True,
        enable_weights_quantization=True,
        lut_values_bitwidth=None,
    )
    # Bias: kept at 32 bits, quantization disabled.
    bias_config = oqc.AttributeQuantizationConfig(
        weights_quantization_method=symmetric,
        weights_n_bits=32,
        weights_per_channel_threshold=False,
        enable_weights_quantization=False,
        lut_values_bitwidth=None,
    )
    # Operator-level config: uniform activation quantization at the requested width.
    linear_eight_bits = mct.target_platform.OpQuantizationConfig(
        activation_quantization_method=mct.target_platform.QuantizationMethod.UNIFORM,
        default_weight_attr_config=default_weight_attr_config,
        attr_weights_configs_mapping={"kernel_attr": kernel_base_config, "bias_attr": bias_config},
        activation_n_bits=activation_n_bits,
        enable_activation_quantization=True,
        quantization_preserving=False,
        fixed_scale=None,
        fixed_zero_point=None,
        simd_size=None,
    )
    return [linear_eight_bits]
def get_model_target_platform_capabilities(activation_n_bits: int) \
        -> mct.target_platform.TargetPlatformCapabilities:
    """
    Get mct's target platform capabilities that are used for quantization like quantization method,
    per channel weights, layer fusing etc.
    TODO: Currently fixed specifications, but need to modified based on quantization params and
    AFE's layer fusions.

    :param activation_n_bits: Bit setting to quantize to.
    :return: pytorch_tpc. Target Platform Capability object with quantization details.
    """
    tp = mct.target_platform
    default_quant_options = tp.get_default_quantization_config_options
    quant_config = _get_op_quantization_configs(activation_n_bits)
    default_configuration_options = tp.QuantizationConfigOptions(quant_config)
    generated_tpc = tp.TargetPlatformModel(default_configuration_options, name='mixed_prec')

    # Stage 1: declare abstract operator sets and their quantization overrides on
    # the target platform model. A value of None means "use the default options".
    with generated_tpc:
        opset_config_options: Dict[str, Optional[tp.QuantizationConfigOptions]] = {
            "NoQuantization": default_quant_options().clone_and_edit(quantization_preserving=True),
            "FullyConnected": default_quant_options().clone_and_edit_weight_attribute(
                weights_per_channel_threshold=False),
            # Ops with fixed output ranges get pinned scale/zero-point.
            "L2Normalization": default_quant_options().clone_and_edit(fixed_zero_point=0,
                                                                      fixed_scale=1 / 128),
            "LogSoftmax": default_quant_options().clone_and_edit(fixed_zero_point=127,
                                                                 fixed_scale=16 / 256),
            "Tanh": default_quant_options().clone_and_edit(fixed_zero_point=0, fixed_scale=1 / 128),
            "Softmax": default_quant_options().clone_and_edit(fixed_zero_point=-128,
                                                              fixed_scale=1 / 256),
            "Logistic": default_quant_options().clone_and_edit(fixed_zero_point=-128,
                                                               fixed_scale=1 / 256),
            "Relu": None,
            "Elu": None,
            "BatchNorm": None,
            "BiasAdd": None,
            "Squeeze": default_quant_options().clone_and_edit(quantization_preserving=True),
        }
        for opset_name, config_option in opset_config_options.items():
            tp.OperatorsSet(opset_name, config_option)
        # Conv2d + Add are declared separately so they can be fused into one op.
        conv2d = tp.OperatorsSet("Conv2d")
        add = tp.OperatorsSet("Add")
        tp.Fusing([conv2d, add])

    pytorch_tpc = tp.TargetPlatformCapabilities(generated_tpc, name='mixed_prec', version='v1')

    # Stage 2: attach concrete torch layers/functions to each operator set.
    # Values are (layers, attr_mapping) pairs; attr_mapping names the torch
    # parameter that backs each weight attribute (kernel/bias), or None.
    # BUGFIX: the annotation previously read Tuple[List[Any, Optional[...]]] —
    # List takes a single type parameter and the tuple bracket was misplaced.
    with pytorch_tpc:
        opset_to_layers_dict: Dict[str, Tuple[List[Any],
                                              Optional[Dict[str, mct.defaultdict.DefaultDict]]]] = {
            "NoQuantization": (
                [
                    torch.nn.AvgPool2d,
                    torch.nn.functional.avg_pool2d,
                    torch.cat,
                    torch.concat,
                    torch.nn.MaxPool2d,
                    torch.nn.functional.max_pool2d,
                    torch.mul,
                    torch.multiply,
                    torch.reshape,
                    tp.LayerFilterParams(torch.nn.functional.interpolate, mode='bilinear'),
                    torch.nn.ZeroPad2d,
                    torch.gather,
                    torch.transpose,
                    torch.maximum,
                    torch.max,
                    torch.minimum,
                    torch.min,
                    torch.nn.functional.pad,
                    torch.select,
                    torch.unbind
                ],
                None
            ),
            "FullyConnected": (
                [torch.nn.Linear, torch.nn.functional.linear],
                {
                    "kernel_attr": mct.defaultdict.DefaultDict(default_value="weight"),
                    "bias_attr": mct.defaultdict.DefaultDict(default_value="bias")
                }
            ),
            "LogSoftmax": ([torch.nn.LogSoftmax], None),
            "Tanh": ([torch.nn.Tanh, torch.nn.functional.tanh], None),
            "Softmax": ([torch.nn.Softmax, torch.nn.functional.softmax], None),
            "Logistic": ([torch.nn.Sigmoid, torch.nn.functional.sigmoid], None),
            "Conv2d": ([torch.nn.Conv2d, torch.nn.functional.conv2d], None),
            "Relu": (
                [
                    torch.relu,
                    torch.nn.ReLU,
                    torch.nn.ReLU6,
                    torch.nn.functional.relu,
                    torch.nn.functional.relu6,
                    # Hardtanh clamped to [0, 6] is equivalent to ReLU6.
                    tp.LayerFilterParams(torch.nn.Hardtanh, min_val=0, max_val=6),
                    tp.LayerFilterParams(torch.nn.functional.hardtanh, min_val=0, max_val=6)
                ],
                None
            ),
            "Elu": ([torch.nn.ELU, torch.nn.functional.elu], None),
            "BatchNorm": ([torch.nn.BatchNorm2d, torch.nn.functional.batch_norm], None),
            "Squeeze": ([torch.squeeze], None),
            "Add": ([operator.add, torch.add], None),
        }
        for opset_name, (layers, attr_mapping) in opset_to_layers_dict.items():
            tp.OperationsSetToLayers(opset_name, layers, attr_mapping)
    return pytorch_tpc