Source code for afe.core.mixed_precision.interface

#########################################################
# Copyright (C) 2024 SiMa Technologies, Inc.
#
# This material is SiMa proprietary and confidential.
#
# This material may not be copied or distributed without
# the express prior written permission of SiMa.
#
# All rights reserved.
#########################################################
# Code owner: Ljubomir Papuga
#########################################################
import torch

from typing import Callable, Iterable

from afe.apis.defines import InputValues

from afe.core.mixed_precision.annotation import annotate_model
from afe.core.mixed_precision.mixed_precision_search import (
    layer_demotion_search, quantize_model, export_torch_fx_to_onnx
)


def mixed_precision_analysis(
    fx_mod: torch.nn.Module,
    calibration_data: Iterable[InputValues],
    accuracy_metric: Callable,
    target_accuracy: float,
    annotated_onnx_filename: str,
) -> bool:
    """
    Implements the mixed-precision quantization algorithm.

    TODO: Implementation details.

    :param fx_mod: Torch representation of the model.
    :param calibration_data: Data used in calibration.
    :param accuracy_metric: One-parameter function used to produce the accuracy metric.
    :param target_accuracy: Accuracy value at which the mixed-precision search algorithm stops.
    :param annotated_onnx_filename: File path to which the annotated ONNX model is to be written.
    :return: True if the target accuracy is less than the 16-bit accuracy; in that case an ONNX
        model containing precision annotations is saved to the annotated_onnx_filename path.
        Otherwise, False.
    """
    model_8_bit, model_16_bit = quantize_model(fx_mod, calibration_data)
    print("Torch FX int8 model accuracy: ", accuracy_metric(model_8_bit))
    print("Torch FX int16 model accuracy: ", accuracy_metric(model_16_bit))
    promotion_list = layer_demotion_search(model_8_bit, model_16_bit, accuracy_metric, target_accuracy)
    if promotion_list:
        export_torch_fx_to_onnx(model_8_bit, annotated_onnx_filename, calibration_data)
        annotate_model(promotion_list, annotated_onnx_filename)
        return True
    else:
        return False
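

#########################################################
# Illustrative usage sketch (not part of the module API): the toy model,
# random calibration batches, accuracy stub, target accuracy, and output
# filename below are placeholder assumptions chosen only to show the
# calling convention of mixed_precision_analysis.
#########################################################
if __name__ == "__main__":
    # Assumption: the model is supplied as a Torch FX GraphModule.
    example_model = torch.fx.symbolic_trace(
        torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU(), torch.nn.Flatten())
    )

    # Assumption: each calibration item maps an input name to a tensor;
    # the name "input" and the shape are placeholders.
    example_calibration = [{"input": torch.randn(1, 3, 32, 32)} for _ in range(4)]

    def example_accuracy(model) -> float:
        # A real metric would evaluate `model` on labeled data and return,
        # e.g., top-1 accuracy; this stub keeps the sketch self-contained.
        return 0.0

    annotated = mixed_precision_analysis(
        example_model,
        example_calibration,
        example_accuracy,
        target_accuracy=0.75,
        annotated_onnx_filename="model_annotated.onnx",
    )
    if annotated:
        print("Wrote precision-annotated ONNX model to model_annotated.onnx")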