Source code for afe.core.mixed_precision.mixed_precision_search

#########################################################
# Copyright (C) 2024 SiMa Technologies, Inc.
#
# This material is SiMa proprietary and confidential.
#
# This material may not be copied or distributed without
# the express prior written permission of SiMa.
#
# All rights reserved.
#########################################################
# Code owner: Shreyas Kera
#########################################################
from typing import Callable, Iterable, List, Tuple, Optional

import copy
import random

import torch

from afe.apis.defines import InputValues

import model_compression_toolkit as mct

from afe.core.mixed_precision.config import get_core_config, get_model_target_platform_capabilities


def get_activation_holders(model: torch.nn.Module) -> List[str]:
    """
    Gets the list of activation holders in the quantized model produced by mct.

    :param model: Torch representation of the model.
    :return: List. List of activation holders in quantized model.
    """
    act_layers = []
    for module in model.named_modules():
        if 'activation_holder_quantizer' in module[0]:
            act_layers.append(module[0])
    return act_layers


def demote_iter(model: torch.nn.Module, accuracy_metric: Callable, search_layers: List[int],
                search_accuracy_threshold: float, acc_8: float, start_index: int,
                act_layers: List[str],
                save_module_8: List[torch.nn.Module]) -> Tuple[torch.nn.Module, float, int, int, List[int]]:
    """
    One iteration of a demotion search. The iteration finds one layer that can be demoted to int8
    without dropping the accuracy below search_accuracy_threshold. It returns a shortened
    search_layers list, as well as the index at which to start the search on the next call to this
    function. If no layer is found that meets the criterion, it returns the "next best" layer,
    i.e. the layer whose demotion results in the smallest accuracy drop.

    Note that the mixed-precision model produced by this method of searching is not unique:
    there can be more than one mix of int8/int16 that meets the accuracy threshold. This is an
    artifact of the search method: if we were to search in a strict order wherein we demote one
    "best" layer per pass through the search list, we would get a unique mixed-precision solution
    and the accuracy would decrease strictly monotonically as a function of passes, but the
    search would be a prohibitive O(n!). So we compromise: in each pass through the search list,
    we demote layers in batches as long as we meet the target accuracy. However, because the
    layers are interdependent, a single pass does not demote all possible layers. This means
    (i) we require multiple passes, (ii) the accuracy does not decrease monotonically and
    (iii) the 8/16 mix at the end is not unique. However,
    - the mixed-precision model is guaranteed to meet the accuracy target, and
    - we find empirically that the last few layers (the ones that drop the accuracy the most when
      demoted) are the same from run to run, so the 8/16 mix is similar even though not identical.

    :param model: Current state of quantized int16 model with certain layers demoted to int8.
    :param accuracy_metric: Function used to produce the accuracy metric.
    :param search_layers: List of indexes of layers left to search.
    :param search_accuracy_threshold: Threshold used to exit demote_iter. This is a value between
        the 16-bit model accuracy and the overall target_acc.
    :param acc_8: Accuracy of 8-bit quantized model.
    :param start_index: Index to start search.
    :param act_layers: List of activation holders in quantized model.
    :param save_module_8: List of saved layers from 8-bit model.
    :return: model: Model with selected layer demoted to 8 bit,
        acc_new: accuracy of new model,
        demote_layer_index: index of demoted layer,
        next_start_index: starting index for next iteration,
        search_layers: indexes of layers left to search with current layer removed.
    """
    end_index = len(search_layers)
    # Current index into search_layers for search
    curr_index = start_index
    save_module = []
    for i in range(len(act_layers)):
        act_layer = act_layers[i]
        save_module.append(getattr(model, act_layer))
    curr_acc = acc_8
    # Index into search_layers for the index with the highest new accuracy
    curr_best_index = 0
    # Initialize acc_new if curr_index >= end_index
    acc_new = curr_acc
    while curr_index < end_index:
        act_layer = act_layers[search_layers[curr_index]]
        setattr(model, act_layer, save_module_8[search_layers[curr_index]])
        acc_new = accuracy_metric(model)
        setattr(model, act_layer, save_module[search_layers[curr_index]])
        if acc_new >= curr_acc:
            curr_acc = acc_new
            curr_best_index = curr_index
        if acc_new > search_accuracy_threshold:
            next_start_index = curr_index
            break
        curr_index += 1
    else:
        next_start_index = 0
    # Index into act_layers
    demote_layer_index = search_layers[curr_best_index]
    demote_layer = act_layers[demote_layer_index]
    setattr(model, demote_layer, save_module_8[demote_layer_index])
    search_layers.pop(curr_best_index)
    if curr_best_index == end_index - 1:
        next_start_index = 0
    return model, acc_new, demote_layer_index, next_start_index, search_layers


def quantize_model(model: torch.nn.Module,
                   calibration_data: Iterable[InputValues]) -> Tuple[torch.nn.Module, torch.nn.Module]:
    """
    Quantize torch fx model.

    :param model: Torch representation of the model.
    :param calibration_data: Data used in calibration.
    :return: model_8_bit, model_16_bit. mct quantized 8-bit and 16-bit models.
    """
    def representative_data_gen():
        return iter(calibration_data)

    mct_ptq_config = get_core_config()
    target_platform_cap = get_model_target_platform_capabilities(activation_n_bits=8)
    model_8_bit, _ = mct.ptq.pytorch_post_training_quantization(
        in_module=model,
        representative_data_gen=representative_data_gen,
        target_platform_capabilities=target_platform_cap,
        core_config=mct_ptq_config
    )
    target_platform_cap = get_model_target_platform_capabilities(activation_n_bits=16)
    model_16_bit, _ = mct.ptq.pytorch_post_training_quantization(
        in_module=model,
        representative_data_gen=representative_data_gen,
        target_platform_capabilities=target_platform_cap,
        core_config=mct_ptq_config
    )
    return model_8_bit, model_16_bit
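

# Illustrative sketch only, not part of this module: a minimal example of preparing the inputs
# that demote_iter expects. The names fx_model and calib_batches are hypothetical, and the
# sketch assumes the 8-bit and 16-bit quantized models expose their activation holders under
# the same module names.
def _quantize_and_collect_sketch(fx_model: torch.nn.Module,
                                 calib_batches: Iterable[InputValues]):
    model_8_bit, model_16_bit = quantize_model(fx_model, calib_batches)
    act_layers = get_activation_holders(model_16_bit)
    # Saved 8-bit activation holders, indexed the same way as act_layers.
    save_module_8 = [getattr(model_8_bit, layer) for layer in act_layers]
    return model_8_bit, model_16_bit, act_layers, save_module_8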


def export_torch_fx_to_onnx(quantized_model: torch.nn.Module,
                            annotated_onnx_filename: str,
                            calibration_data: Iterable[InputValues]) -> None:
    """
    Export the model with activation holders to onnx.

    :param quantized_model: Model quantized by mct with Activation Holders.
    :param annotated_onnx_filename: File path to which the annotated ONNX model is to be written.
    :param calibration_data: Data used in calibration.
    :return: None. Writes the output ONNX file containing activation holders to the
        annotated_onnx_filename path.
    """
    def representative_data_gen():
        return iter(calibration_data)

    mct.exporter.pytorch_export_model(model=quantized_model,
                                      save_model_path=annotated_onnx_filename,
                                      repr_dataset=representative_data_gen,
                                      onnx_opset_version=17)
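

# Illustrative end-to-end sketch only, composing the hypothetical helpers sketched above with the
# module's export function. The model, calibration data, thresholds and output path are
# placeholders, not part of this module's API.
def _mixed_precision_flow_sketch(fx_model: torch.nn.Module,
                                 calib_batches: Iterable[InputValues],
                                 accuracy_metric: Callable,
                                 search_accuracy_threshold: float,
                                 target_acc: float,
                                 annotated_onnx_filename: str) -> None:
    model_8_bit, model_16_bit, act_layers, save_module_8 = _quantize_and_collect_sketch(
        fx_model, calib_batches)
    acc_8 = accuracy_metric(model_8_bit)
    mixed_model = _demotion_search_sketch(model_16_bit, accuracy_metric, act_layers,
                                          save_module_8, search_accuracy_threshold,
                                          target_acc, acc_8)
    export_torch_fx_to_onnx(mixed_model, annotated_onnx_filename, calib_batches)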