# Source code for afe.core.mixed_precision.config
#########################################################
# Copyright (C) 2024 SiMa Technologies, Inc.
#
# This material is SiMa proprietary and confidential.
#
# This material may not be copied or distributed without
# the express prior written permission of SiMa.
#
# All rights reserved.
#########################################################
# Code owner: Shreyas Kera
#########################################################
import torch
import math
import operator
import model_compression_toolkit as mct
from typing import Any, Dict, List, Optional, Tuple
# [docs]  (Sphinx cross-reference artifact from the rendered source page)
def get_core_config() -> mct.core.CoreConfig:
    """
    Build mct's core configuration describing error methods, bias correction, etc.

    TODO: Currently a fixed config, needs to be configurable based on user's
    choices from AFE specs.

    :return: CoreConfig carrying the quantization-specific details.
    """
    # All settings are fixed for now; grouped by theme for readability.
    quantization_config = mct.core.QuantizationConfig(
        # Threshold-selection error metrics.
        activation_error_method=mct.core.QuantizationErrorMethod.MSE,
        weights_error_method=mct.core.QuantizationErrorMethod.NOCLIPPING,
        l_p_value=2,
        z_threshold=math.inf,
        min_threshold=2 ** -16,
        # Weight-oriented corrections.
        weights_bias_correction=True,
        weights_second_moment_correction=False,
        input_scaling=False,
        # Activation-oriented transforms (all disabled).
        relu_bound_to_power_of_2=False,
        softmax_shift=False,
        shift_negative_activation_correction=False,
        activation_channel_equalization=False,
        shift_negative_ratio=0.05,
        shift_negative_threshold_recalculation=False,
        shift_negative_params_search=False,
        # Graph-level collapsing optimizations (disabled).
        linear_collapsing=False,
        residual_collapsing=False
    )
    return mct.core.CoreConfig(quantization_config)
def _get_op_quantization_configs(activation_n_bits: int) -> List[mct.target_platform.OpQuantizationConfig]:
    """
    Build the list of operator quantization configuration options.

    :param activation_n_bits: Number of bits used for activation quantization.
    :return: Configuration options for operator quantization.
    """
    op_quant_config = mct.target_platform.op_quantization_config

    def _symmetric_attr_config(n_bits: int, per_channel: bool,
                               enabled: bool) -> Any:
        # All weight-attribute configs share the symmetric method and no LUT;
        # only bit-width, per-channel thresholding and enablement vary.
        return op_quant_config.AttributeQuantizationConfig(
            weights_quantization_method=mct.target_platform.QuantizationMethod.SYMMETRIC,
            weights_n_bits=n_bits,
            weights_per_channel_threshold=per_channel,
            enable_weights_quantization=enabled,
            lut_values_bitwidth=None
        )

    # Default attribute: 8-bit, per-tensor, quantization disabled.
    default_weight_attr_config = _symmetric_attr_config(8, False, False)
    # Kernel: 8-bit, per-channel thresholds, quantization enabled.
    kernel_base_config = _symmetric_attr_config(8, True, True)
    # Bias: kept in 32 bits, not quantized.
    bias_config = _symmetric_attr_config(32, False, False)

    linear_eight_bits = mct.target_platform.OpQuantizationConfig(
        activation_quantization_method=mct.target_platform.QuantizationMethod.UNIFORM,
        default_weight_attr_config=default_weight_attr_config,
        attr_weights_configs_mapping={"kernel_attr": kernel_base_config, "bias_attr": bias_config},
        activation_n_bits=activation_n_bits,
        enable_activation_quantization=True,
        quantization_preserving=False,
        fixed_scale=None,
        fixed_zero_point=None,
        simd_size=None
    )
    return [linear_eight_bits]
# [docs]  (Sphinx cross-reference artifact from the rendered source page)
def get_model_target_platform_capabilities(activation_n_bits: int) \
        -> mct.target_platform.TargetPlatformCapabilities:
    """
    Get mct's target platform capabilities that are used for quantization like quantization method,
    per channel weights, layer fusing etc.

    TODO: Currently fixed specifications, but need to modified based on quantization params and
    AFE's layer fusions.

    :param activation_n_bits: Bit setting to quantize to.
    :return: pytorch_tpc. Target Platform Capability object with quantization details.
    """
    tp = mct.target_platform
    # NOTE: this is a function *reference*, not a call — each `default_quant_options()`
    # below fetches a fresh copy of the default QuantizationConfigOptions to clone-and-edit.
    default_quant_options = tp.get_default_quantization_config_options
    quant_config = _get_op_quantization_configs(activation_n_bits)
    default_configuration_options = tp.QuantizationConfigOptions(quant_config)
    # The TargetPlatformModel context registers every OperatorsSet/Fusing created inside it.
    generated_tpc = tp.TargetPlatformModel(default_configuration_options, name='mixed_prec')
    with generated_tpc:
        # Opset name -> quantization options; None means "inherit the model defaults".
        # fixed_scale/fixed_zero_point pin the output quantization of bounded-range ops
        # (e.g. Tanh output in [-1, 1] maps to scale 1/128, zero point 0).
        opset_config_options: Dict[str, Optional[tp.QuantizationConfigOptions]] = {
            "NoQuantization": default_quant_options().clone_and_edit(quantization_preserving=True),
            "FullyConnected": default_quant_options().clone_and_edit_weight_attribute(
                weights_per_channel_threshold=False),
            "L2Normalization": default_quant_options().clone_and_edit(fixed_zero_point=0, fixed_scale=1 / 128),
            "LogSoftmax": default_quant_options().clone_and_edit(fixed_zero_point=127, fixed_scale=16 / 256),
            "Tanh": default_quant_options().clone_and_edit(fixed_zero_point=0, fixed_scale=1 / 128),
            "Softmax": default_quant_options().clone_and_edit(fixed_zero_point=-128, fixed_scale=1 / 256),
            "Logistic": default_quant_options().clone_and_edit(fixed_zero_point=-128, fixed_scale=1 / 256),
            "Relu": None,
            "Elu": None,
            "BatchNorm": None,
            "BiasAdd": None,
            "Squeeze": default_quant_options().clone_and_edit(quantization_preserving=True),
        }
        for opset_name, config_option in opset_config_options.items():
            tp.OperatorsSet(opset_name, config_option)
        # Declare Conv2d -> Add as a fusable pattern (quantized as one unit).
        conv2d = tp.OperatorsSet("Conv2d")
        add = tp.OperatorsSet("Add")
        tp.Fusing([conv2d, add])
    # Attach the framework-agnostic model above to concrete PyTorch layers/functions.
    pytorch_tpc = tp.TargetPlatformCapabilities(generated_tpc, name='mixed_prec', version='v1')
    with pytorch_tpc:
        # Opset name -> (torch layers/ops in that set, optional weight-attribute name mapping).
        # The attr mapping tells mct which torch attribute holds the kernel/bias tensors.
        opset_to_layers_dict: Dict[str, Tuple[List[Any], Optional[Dict[str, mct.defaultdict.DefaultDict]]]] = {
            "NoQuantization": (
                [
                    torch.nn.AvgPool2d,
                    torch.nn.functional.avg_pool2d,
                    torch.cat,
                    torch.concat,
                    torch.nn.MaxPool2d,
                    torch.nn.functional.max_pool2d,
                    torch.mul,
                    torch.multiply,
                    torch.reshape,
                    # Only bilinear interpolation is matched; other modes fall through.
                    tp.LayerFilterParams(torch.nn.functional.interpolate, mode='bilinear'),
                    torch.nn.ZeroPad2d,
                    torch.gather,
                    torch.transpose,
                    torch.maximum,
                    torch.max,
                    torch.minimum,
                    torch.min,
                    torch.nn.functional.pad,
                    torch.select,
                    torch.unbind
                ],
                None
            ),
            "FullyConnected": (
                [torch.nn.Linear, torch.nn.functional.linear],
                {
                    "kernel_attr": mct.defaultdict.DefaultDict(default_value="weight"),
                    "bias_attr": mct.defaultdict.DefaultDict(default_value="bias")
                }
            ),
            "LogSoftmax": ([torch.nn.LogSoftmax], None),
            "Tanh": ([torch.nn.Tanh, torch.nn.functional.tanh], None),
            "Softmax": ([torch.nn.Softmax, torch.nn.functional.softmax], None),
            "Logistic": ([torch.nn.Sigmoid, torch.nn.functional.sigmoid], None),
            "Conv2d": ([torch.nn.Conv2d, torch.nn.functional.conv2d], None),
            "Relu": (
                [
                    torch.relu,
                    torch.nn.ReLU,
                    torch.nn.ReLU6,
                    torch.nn.functional.relu,
                    torch.nn.functional.relu6,
                    # Hardtanh clamped to [0, 6] behaves as ReLU6, so match it here.
                    tp.LayerFilterParams(torch.nn.Hardtanh, min_val=0, max_val=6),
                    tp.LayerFilterParams(torch.nn.functional.hardtanh, min_val=0, max_val=6)
                ],
                None
            ),
            "Elu": ([torch.nn.ELU, torch.nn.functional.elu], None),
            "BatchNorm": ([torch.nn.BatchNorm2d, torch.nn.functional.batch_norm], None),
            "Squeeze": ([torch.squeeze], None),
            "Add": ([operator.add, torch.add], None),
        }
        for opset_name, (layers, attr_mapping) in opset_to_layers_dict.items():
            tp.OperationsSetToLayers(opset_name, layers, attr_mapping)
    return pytorch_tpc