Source code for afe.ir.bias_correction

#########################################################
# Copyright (C) 2024 SiMa Technologies, Inc.
#
# This material is SiMa proprietary and confidential.
#
# This material may not be copied or distributed without
# the express prior written permission of SiMa.
#
# All rights reserved.
#########################################################
# Code owner: Christopher Rodrigues
#########################################################
"""
Bias correction algorithms.  These algorithms correct for
unwanted bias in convolution or matrix multiply that is caused
by quantization.
"""
import numpy as np

import ml_kernels.math_helpers
from afe.ir.defines import Quantization
from afe.ir.quantization_utils import dequantize


def prepare_input_mean(calibration_input: list[np.ndarray], quantization: Quantization | None) -> np.ndarray:
    """
    Dequantize a set of calibration samples and calculate their per-channel mean value.
    This function is intended to calculate a convolution's mean input value for
    iterative bias correction.

    Args:
        calibration_input: Set of quantized calibration samples.
        quantization: Quantization of the input data. If None, the input data is not quantized.

    Returns:
        Dequantized mean value. It is a 1D array.
    """
    input_data = np.concatenate(calibration_input, axis=0)
    if quantization is not None:
        dequantized = dequantize(input_data, 1 / quantization.scale, quantization.zero_point)
    else:
        assert ml_kernels.math_helpers.is_float_type(input_data.dtype)
        dequantized = input_data.astype(np.float32)
    return np.mean(dequantized, axis=(0, 1, 2))
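
# A minimal usage sketch for prepare_input_mean (illustration only, not part of this module).
# It assumes NHWC calibration samples; passing quantization=None exercises the unquantized
# path, so no Quantization object needs to be constructed here.
def _example_prepare_input_mean() -> np.ndarray:
    rng = np.random.default_rng(0)
    # Two float32 calibration samples with shape (N, H, W, C) = (1, 8, 8, 16).
    samples = [rng.standard_normal((1, 8, 8, 16)).astype(np.float32) for _ in range(2)]
    input_mean = prepare_input_mean(samples, quantization=None)
    assert input_mean.shape == (16,)  # One mean value per input channel
    return input_mean
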
class BiasCorrector:
    """
    Abstract base class of a bias correction algorithm.

    The constructor may take parameters that are used for bias correction,
    such as calibration data.
    """
    def __init__(self):
        raise NotImplementedError("Class is abstract")
    def calculate(self, weights: np.ndarray, fake_quantized_weights: np.ndarray) -> np.ndarray | None:
        """
        Calculate bias correction. Returns a floating-point value that should be added
        to the convolution's bias before it is quantized.

        :param weights: Floating-point weight tensor.
        :param fake_quantized_weights: Fake quantized weight tensor. It is the result of
            quantizing and then dequantizing the weights.
        :return: Bias correction value
        """
        raise NotImplementedError("Method is abstract")
class NullBiasCorrector(BiasCorrector):
    """
    No bias correction.
    """
    def __init__(self):
        pass
    def calculate(self, weights: np.ndarray, fake_quantized_weights: np.ndarray) -> np.ndarray | None:
        """
        Do no bias correction.
        """
        return None
class MeanBiasCorrector(BiasCorrector):
    """
    Bias correction based on the observed mean value of the input activation tensor.
    Statistics correction as described in https://arxiv.org/abs/1906.04721, Section 4.2.

    :param input_mean: Observed per-channel mean value of the convolution's input.
        A one-dimensional tensor.
    """
    def __init__(self, input_mean: np.ndarray):
        assert len(input_mean.shape) == 1, "input_mean must be one-dimensional"
        self._input_mean = input_mean
    def calculate(self, weights: np.ndarray, fake_quantized_weights: np.ndarray) -> np.ndarray | None:
        assert weights.shape == fake_quantized_weights.shape
        num_spatial_dimensions = len(weights.shape) - 3
        groups = weights.shape[-2]

        # Compute a correction to the weight tensor. The correction is -1 times the
        # estimated quantization error; the factor of -1 comes from this subtraction.
        diff_weights = weights - fake_quantized_weights

        # Sum the error over the spatial dimensions, distinguishing only channels.
        eps_weights = np.sum(diff_weights, axis=tuple(range(num_spatial_dimensions)))
        assert len(eps_weights.shape) == 3

        # Compute a matrix-vector product for each group. The result is the correction
        # to the output, given the mean activation value.
        input_mean_groups = np.split(self._input_mean, groups, axis=0)
        eps_weights_groups = eps_weights.transpose((1, 0, 2))
        assert len(eps_weights_groups) == len(input_mean_groups)
        quant_bias_correction_terms = \
            [np.transpose(np.matmul(mu, eps)) for mu, eps in zip(input_mean_groups, eps_weights_groups)]
        return np.concatenate(quant_bias_correction_terms, axis=0)
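
# A minimal usage sketch for MeanBiasCorrector (illustration only, not part of this module).
# It fake-quantizes a random weight tensor whose layout is assumed to match calculate's
# indexing, (kh, kw, in_channels_per_group, groups, out_channels_per_group), and computes
# the bias correction. The symmetric int8 fake quantization below is an illustrative
# assumption, not this package's quantization code. The returned correction is a
# per-output-channel value to add to the convolution's floating-point bias before the
# bias is quantized.
def _example_mean_bias_correction() -> np.ndarray:
    rng = np.random.default_rng(0)
    weights = rng.standard_normal((3, 3, 16, 1, 8)).astype(np.float32)
    # Symmetric per-tensor int8 fake quantization: quantize, then dequantize.
    scale = float(np.abs(weights).max()) / 127.0
    fake_quantized_weights = np.clip(np.round(weights / scale), -127, 127) * scale
    # Per-channel mean of the convolution's input, e.g. the output of prepare_input_mean.
    input_mean = rng.standard_normal(16).astype(np.float32)
    correction = MeanBiasCorrector(input_mean).calculate(weights, fake_quantized_weights)
    assert correction is not None and correction.shape == (8,)  # One value per output channel
    return correction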