#########################################################
# Copyright (C) 2024 SiMa Technologies, Inc.
#
# This material is SiMa proprietary and confidential.
#
# This material may not be copied or distributed without
# the express prior written permission of SiMa.
#
# All rights reserved.
#########################################################
# Code owner: Ljubomir Papuga
#########################################################
import torch
from typing import Callable, Iterable
from afe.apis.defines import InputValues
from afe.core.mixed_precision.annotation import annotate_model
from afe.core.mixed_precision.mixed_precision_search import (
layer_demotion_search, quantize_model, export_torch_fx_to_onnx
)
def mixed_precision_analysis(fx_mod: torch.nn.Module, calibration_data: Iterable[InputValues],
                             accuracy_metric: Callable, target_accuracy: float,
                             annotated_onnx_filename: str) -> bool:
    """
    Run the mixed-precision quantization search and export an annotated ONNX model if needed.

    Steps performed:
      1. Quantize ``fx_mod`` into an int8 and an int16 variant using ``calibration_data``.
      2. Print the accuracy of both variants as reported by ``accuracy_metric``.
      3. Run ``layer_demotion_search`` to obtain the list of layers to promote.
      4. If that list is non-empty, export the int8 model to ONNX at ``annotated_onnx_filename``
         and pass the promotion list and the same path to ``annotate_model``.

    TODO: Implementation details of the search algorithm.

    :param fx_mod: Torch representation of the model.
    :param calibration_data: Data used in calibration. It is consumed twice, once by quantization
        and once by the ONNX export. A one-shot iterator (e.g. a generator) is therefore
        materialized into a list first. Re-iterable inputs are used as-is.
    :param accuracy_metric: One-parameter function that takes a model and returns its accuracy.
    :param target_accuracy: Value used to determine the end of the mixed-precision search.
    :param annotated_onnx_filename: File path to which the annotated ONNX model is written.
    :return: True if ``layer_demotion_search`` returned a non-empty promotion list, in which case
        the annotated ONNX model has been written to ``annotated_onnx_filename``. False otherwise,
        in which case nothing is written.
        NOTE(review): presumably the list is non-empty only when the target accuracy is below
        the int16 accuracy; confirm against ``layer_demotion_search``.
    """
    from collections.abc import Iterator

    # Both quantize_model and export_torch_fx_to_onnx iterate over the calibration data. A
    # one-shot iterator would be exhausted after the first pass, so snapshot it here.
    # Re-iterable containers are not copied, to avoid materializing large datasets.
    if isinstance(calibration_data, Iterator):
        calibration_data = list(calibration_data)

    model_8_bit, model_16_bit = quantize_model(fx_mod, calibration_data)
    print("Torch FX int8 model accuracy: ", accuracy_metric(model_8_bit))
    print("Torch FX int16 model accuracy: ", accuracy_metric(model_16_bit))

    promotion_list = layer_demotion_search(model_8_bit, model_16_bit, accuracy_metric, target_accuracy)
    if not promotion_list:
        # Empty promotion list: no ONNX model is exported.
        return False

    # Export the int8 model, then hand the same file path to annotate_model together with the
    # promotion list (presumably it writes the precision annotations into that file; confirm).
    export_torch_fx_to_onnx(model_8_bit, annotated_onnx_filename, calibration_data)
    annotate_model(promotion_list, annotated_onnx_filename)
    return True