#########################################################
# Copyright (C) 2024 SiMa Technologies, Inc.
#
# This material is SiMa proprietary and confidential.
#
# This material may not be copied or distributed without
# the express prior written permission of SiMa.
#
# All rights reserved.
#########################################################
# Code owner: Shreyas Kera
#########################################################
from typing import Callable, Iterable, List, Tuple, Optional
import copy
import random
import torch
from afe.apis.defines import InputValues
import model_compression_toolkit as mct
from afe.core.mixed_precision.config import get_core_config, get_model_target_platform_capabilities
def get_activation_holders(model: torch.nn.Module) -> List[str]:
    """
    Get the list of activation holder layers in a model quantized by mct.

    :param model: Torch representation of the model.
    :return: List. List of names of activation holder layers in the quantized model.
    """
    act_layers = []
    for name, _ in model.named_modules():
        if 'activation_holder_quantizer' in name:
            act_layers.append(name)
    return act_layers
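# Example (illustrative; assumes `model_8_bit` came from
# mct.ptq.pytorch_post_training_quantization): the returned entries are the names of
# submodules containing the substring 'activation_holder_quantizer', one per quantized
# activation, e.g.
#   act_layers = get_activation_holders(model_8_bit)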
def demote_iter(model: torch.nn.Module, accuracy_metric: Callable,
                search_layers: List[int], search_accuracy_threshold: float, acc_8: float,
                start_index: int, act_layers: List[str],
                save_module_8: List[torch.nn.Module]) -> Tuple[torch.nn.Module, float, int, int, List[int]]:
    """
    One iteration of a demotion search.

    The iteration finds one layer that can be demoted to int8 without dropping the accuracy below
    search_accuracy_threshold. It returns a shortened search_layers list, as well as the index at which to
    start the search on the next call to this function. If no layer is found that meets the criterion, it
    returns the "next best" layer, i.e. the layer whose demotion results in the smallest accuracy drop.

    Note that the mixed-precision model produced by this method of searching is not unique: there can be
    more than one mix of int8/int16 that meets the accuracy threshold.
    This is an artifact of the search method: If we were to search in a strict order wherein we demote one
    "best" layer per each pass through the search list, we would get a unique mixed-precision solution and
    the accuracy would decrease strictly monotonically as a function of passes, but the search would be
    prohibitively slow. So, we compromise: in each pass through the search list, we demote layers
    in batches as long as we meet the target accuracy. However, because the layers are interdependent, a
    single pass does not demote all possible layers.
    This means
    (i) we require multiple passes,
    (ii) the accuracy does not decrease monotonically and
    (iii) the 8/16 mix at the end is not unique.
    However,
    - the mixed-precision model is guaranteed to meet the accuracy target, and
    - we find empirically that the last few layers (the ones that drop the accuracy the most when demoted)
      are the same from run to run, so the 8/16 mix is similar even though not identical.

    :param model: Current state of the quantized int16 model with certain layers demoted to int8.
    :param accuracy_metric: Function used to produce the accuracy metric.
    :param search_layers: List of indexes of layers left to search.
    :param search_accuracy_threshold: Threshold used to exit demote_iter. This is a value between the
        16-bit model accuracy and the overall target_acc.
    :param acc_8: Accuracy of the 8-bit quantized model.
    :param start_index: Index at which to start the search.
    :param act_layers: List of activation holders in the quantized model.
    :param save_module_8: List of saved layers from the 8-bit model.
    :return: model: Model with the selected layer demoted to 8 bit, acc_new: accuracy of the new model,
        demote_layer_index: index of the demoted layer, next_start_index: starting index for the next
        iteration, search_layers: indexes of layers left to search with the current layer removed.
    """
    end_index = len(search_layers)
    # Current index into search_layers for the search
    curr_index = start_index
    save_module = []
    for i in range(len(act_layers)):
        act_layer = act_layers[i]
        save_module.append(getattr(model, act_layer))
    curr_acc = acc_8
    # Index into search_layers for the index with the highest new accuracy
    curr_best_index = 0
    # Initialize acc_new in case curr_index >= end_index
    acc_new = curr_acc
    while curr_index < end_index:
        act_layer = act_layers[search_layers[curr_index]]
        # Temporarily demote the candidate layer to int8, measure, then restore it.
        setattr(model, act_layer, save_module_8[search_layers[curr_index]])
        acc_new = accuracy_metric(model)
        setattr(model, act_layer, save_module[search_layers[curr_index]])
        if acc_new >= curr_acc:
            curr_acc = acc_new
            curr_best_index = curr_index
            if acc_new > search_accuracy_threshold:
                next_start_index = curr_index
                break
        curr_index += 1
    else:
        # The loop finished without finding a layer above the threshold.
        next_start_index = 0
    # Index into act_layers
    demote_layer_index = search_layers[curr_best_index]
    demote_layer = act_layers[demote_layer_index]
    setattr(model, demote_layer, save_module_8[demote_layer_index])
    search_layers.pop(curr_best_index)
    if curr_best_index == end_index - 1:
        next_start_index = 0
    return model, acc_new, demote_layer_index, next_start_index, search_layers
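# Illustrative trace of one demote_iter call (hypothetical accuracies): with four
# activation holders, search_layers == [3, 1, 0, 2], acc_8 == 0.70,
# search_accuracy_threshold == 0.78 and start_index == 0, suppose demoting act_layers[3]
# measures 0.75 and demoting act_layers[1] measures 0.80. Layer 1 exceeds the threshold,
# so it is committed to int8; the call returns acc_new == 0.80, demote_layer_index == 1,
# next_start_index == 1 and search_layers == [3, 0, 2], so the next call resumes with
# act_layers[0].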
def layer_demotion_search(model_8_bit: torch.nn.Module, model_16_bit: torch.nn.Module,
                          accuracy_metric: Callable[[torch.nn.Module], float], target_acc: float) \
        -> Optional[List[str]]:
    """
    Get the list of layers to promote to 16 bit based on a demotion search.

    Note: The accuracy does not decrease monotonically as layers are demoted.
    Empirically we find that demoting an individual layer to int8 sometimes results in
    a decrease in accuracy below search_acc_threshold, but the accuracy recovers when a
    neighboring layer is also quantized to int8. We want to avoid terminating on these
    local minima, so we continue through all the layers, and then search backward through
    the mixed models until we find the one with the largest number of layers in int8 that
    still meets the target accuracy.

    TODO: Implementation details.

    :param model_8_bit: 8-bit mct quantized model.
    :param model_16_bit: 16-bit mct quantized model.
    :param accuracy_metric: Function used to produce the accuracy metric.
    :param target_acc: Value used to determine the end of the mixed precision search algorithm.
    :return: promotion_list. List of layer names to promote to 16 bit. Returns None if the
        16-bit model does not reach target_acc, and an empty list if the 8-bit model already does.
    """
    # This is a parameter that should be moved to some sort of mixed_precision_config.
    # It is a value between 0 and 1 that is used to set the accuracy threshold for the
    # demote_iter loop. Close to 0, the results are repeatable, but the search is slow.
    # Closer to 1, the search is faster, but the resulting int16/int8 mix may sometimes
    # be suboptimal.
    demote_iter_threshold_multiplier = 0.5
    acc_16 = accuracy_metric(model_16_bit)
    if acc_16 < target_acc:
        return None
    # Evaluate the 8-bit model once and reuse the result for both the early exit and the search.
    acc_8 = accuracy_metric(model_8_bit)
    if acc_8 > target_acc:
        return []
    search_acc_threshold = acc_16 - demote_iter_threshold_multiplier * (acc_16 - target_acc)
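    # For example (hypothetical numbers): with acc_16 == 0.80, target_acc == 0.76 and the
    # multiplier at 0.5, search_acc_threshold == 0.80 - 0.5 * (0.80 - 0.76) == 0.78,
    # i.e. halfway between the int16 accuracy and the target accuracy.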
    act_layers = get_activation_holders(model_8_bit)
    model = copy.deepcopy(model_16_bit)
    count_act_layers = len(act_layers)
    search_layers = list(range(count_act_layers))
    save_module_8 = []
    for i in range(count_act_layers):
        act_layer = act_layers[i]
        save_module_8.append(getattr(model_8_bit, act_layer))
    demote_accuracy = []
    demote_layers = []
    start_index = 0
    rng = random.Random(1)
    while len(search_layers) >= 1:
        if start_index == 0:
            rng.shuffle(search_layers)
        # TODO: Currently, when demote_iter searches to the end of search_layers, it
        #  returns the curr_best_index computed only over the last segment of search_layers.
        #  The results may be more reliable if it returned the best index computed over the whole list.
        model, new_acc, demote_index, start_index, search_layers = demote_iter(
            model, accuracy_metric, search_layers, search_acc_threshold, acc_8, start_index, act_layers,
            save_module_8
        )
        demote_accuracy.append(new_acc)
        demote_layers.append(act_layers[demote_index])
    # Search backward through the demotion history for the mix with the most int8 layers
    # that still meets the target accuracy.
    accuracies = demote_accuracy[::-1]
    layers = demote_layers[::-1]
    for i, accuracy in enumerate(accuracies):
        if accuracy > target_acc:
            promotion_list = layers[:i]
            return promotion_list
    # If no recorded mix meets the target accuracy, promote all layers, i.e. fall back to the int16 model.
    return layers
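# How a caller might interpret the result (illustrative sketch; `model_8_bit`,
# `model_16_bit` and `top1_accuracy` are hypothetical placeholders):
#   promotion_list = layer_demotion_search(model_8_bit, model_16_bit, top1_accuracy, target_acc=0.75)
#   if promotion_list is None:
#       ...  # Even the full int16 model misses the target; mixed precision cannot help.
#   elif not promotion_list:
#       ...  # The int8 model already meets the target; nothing needs to stay at int16.
#   else:
#       ...  # Keep only the named activation holders at int16; the rest run at int8.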
def quantize_model(model: torch.nn.Module,
                   calibration_data: Iterable[InputValues]) -> Tuple[torch.nn.Module, torch.nn.Module]:
    """
    Quantize torch fx model.

    :param model: Torch representation of the model.
    :param calibration_data: Data used in calibration.
    :return: model_8_bit, model_16_bit. mct quantized 8-bit and 16-bit models.
    """
    def representative_data_gen():
        return iter(calibration_data)

    mct_ptq_config = get_core_config()
    target_platform_cap = get_model_target_platform_capabilities(activation_n_bits=8)
    model_8_bit, _ = mct.ptq.pytorch_post_training_quantization(
        in_module=model,
        representative_data_gen=representative_data_gen,
        target_platform_capabilities=target_platform_cap,
        core_config=mct_ptq_config
    )
    target_platform_cap = get_model_target_platform_capabilities(activation_n_bits=16)
    model_16_bit, _ = mct.ptq.pytorch_post_training_quantization(
        in_module=model,
        representative_data_gen=representative_data_gen,
        target_platform_capabilities=target_platform_cap,
        core_config=mct_ptq_config
    )
    return model_8_bit, model_16_bit
def export_torch_fx_to_onnx(quantized_model: torch.nn.Module, annotated_onnx_filename: str,
                            calibration_data: Iterable[InputValues]) -> None:
    """
    Export the model with activation holders to onnx.

    :param quantized_model: Model quantized by mct with Activation Holders.
    :param annotated_onnx_filename: File path to which the annotated ONNX model is to be written.
    :param calibration_data: Data used in calibration.
    :return: None. Writes the output ONNX file containing activation holders to the
        annotated_onnx_filename path.
    """
    def representative_data_gen():
        return iter(calibration_data)

    mct.exporter.pytorch_export_model(model=quantized_model,
                                      save_model_path=annotated_onnx_filename,
                                      repr_dataset=representative_data_gen,
                                      onnx_opset_version=17)
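# End-to-end sketch (illustrative; `fx_model`, `calib_data` and `top1_accuracy` are
# hypothetical placeholders for a torch.fx-traced model, an Iterable[InputValues] and an
# accuracy callable):
#   model_8_bit, model_16_bit = quantize_model(fx_model, calib_data)
#   promotion_list = layer_demotion_search(model_8_bit, model_16_bit, top1_accuracy, target_acc=0.75)
#   export_torch_fx_to_onnx(model_8_bit, annotated_onnx_filename, calib_data)
# How promotion_list is applied to the exported model is handled outside this module.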