#########################################################
# Copyright (C) 2024 SiMa Technologies, Inc.
#
# This material is SiMa proprietary and confidential.
#
# This material may not be copied or distributed without
# the express prior written permission of SiMa.
#
# All rights reserved.
#########################################################
# Code owner: Christopher Rodrigues
#########################################################
"""
Bias correction algorithms. These algorithms correct for
unwanted bias in convolution or matrix multiply that is caused
by quantization.
"""
import numpy as np
import ml_kernels.math_helpers
from afe.ir.defines import Quantization
from afe.ir.quantization_utils import dequantize
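
# Why this works: fake quantization perturbs the weights by
# eps = fake_quantized_weights - weights. Because convolution is linear, the
# expected output error is E[eps @ x] = eps @ E[x], so adding the opposite
# term, (weights - fake_quantized_weights) @ E[x], to the bias cancels the
# error. MeanBiasCorrector below computes this term per output channel,
# following Section 4.2 of https://arxiv.org/abs/1906.04721.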


class BiasCorrector:
    """
    Abstract base class of a bias correction algorithm.
    The constructor may take parameters that are used for bias correction, such as calibration data.
    """
    def __init__(self):
        raise NotImplementedError("Class is abstract")

    def calculate(self, weights: np.ndarray, fake_quantized_weights: np.ndarray) -> np.ndarray | None:
        """
        Calculate bias correction. Returns a floating-point tensor that should be added
        to the convolution's bias before the bias is quantized.

        :param weights: Floating-point weight tensor.
        :param fake_quantized_weights: Fake quantized weight tensor. It is the result of
            quantizing and then dequantizing the weights.
        :return: Bias correction value, or None if no correction should be applied.
        """
        raise NotImplementedError("Method is abstract")
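

# A minimal sketch of how a quantization pass might consume this interface.
# The function below and its parameter names are illustrative assumptions,
# not part of the SiMa quantization flow.
def _example_apply_bias_correction(corrector: BiasCorrector, bias: np.ndarray,
                                   weights: np.ndarray,
                                   fake_quantized_weights: np.ndarray) -> np.ndarray:
    # The corrector returns a floating-point additive term, or None when no
    # correction is wanted (for example, with NullBiasCorrector).
    correction = corrector.calculate(weights, fake_quantized_weights)
    # The correction is applied to the bias before the bias itself is quantized.
    return bias if correction is None else bias + correction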


class NullBiasCorrector(BiasCorrector):
    """
    No bias correction.
    """
    def __init__(self):
        pass

    def calculate(self, weights: np.ndarray, fake_quantized_weights: np.ndarray) -> np.ndarray | None:
        """
        Do no bias correction.
        """
        return None


class MeanBiasCorrector(BiasCorrector):
    """
    Bias correction based on the observed mean value of the input activation tensor.
    Statistics correction as described in https://arxiv.org/abs/1906.04721, Section 4.2.

    :param input_mean: Observed per-channel mean value of the convolution's input.
        A one-dimensional tensor.
    """
    def __init__(self, input_mean: np.ndarray):
        assert len(input_mean.shape) == 1, "input_mean must be one-dimensional"
        self._input_mean = input_mean

    def calculate(self, weights: np.ndarray, fake_quantized_weights: np.ndarray) -> np.ndarray | None:
        assert weights.shape == fake_quantized_weights.shape
        num_spatial_dimensions = len(weights.shape) - 3
        groups = weights.shape[-2]
        # Compute a correction to the weight tensor. The correction is -1 times the
        # estimated quantization error; the factor of -1 comes from the operand order
        # in this subtraction.
        diff_weights = weights - fake_quantized_weights
        # Sum the error over the spatial dimensions, distinguishing only channels. The
        # result has shape (input channels per group, groups, output channels per group).
        eps_weights = np.sum(diff_weights, axis=tuple(range(num_spatial_dimensions)))
        assert len(eps_weights.shape) == 3
        # Compute a matrix-vector product on each group. The result is the correction
        # to the output, given the mean activation value.
        input_mean_groups = np.split(self._input_mean, groups, axis=0)
        eps_weights_groups = eps_weights.transpose((1, 0, 2))
        assert len(eps_weights_groups) == len(input_mean_groups)
        quant_bias_correction_terms = \
            [np.transpose(np.matmul(mu, eps)) for mu, eps in zip(input_mean_groups, eps_weights_groups)]
        return np.concatenate(quant_bias_correction_terms, axis=0)
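

# A minimal self-check sketch (illustrative, not part of the module's API). The
# shapes below are assumptions chosen to match the layout calculate() expects:
# weights of shape (spatial..., C_in_per_group, groups, C_out_per_group), and an
# input mean with one entry per input channel (groups * C_in_per_group).
if __name__ == "__main__":
    rng = np.random.default_rng(0)
    weights = rng.standard_normal((3, 3, 2, 2, 8)).astype(np.float32)
    # Simulate fake quantization by snapping the weights to a coarse grid.
    fake_quantized_weights = (np.round(weights * 16) / 16).astype(np.float32)
    input_mean = rng.standard_normal(4).astype(np.float32)  # 2 groups x 2 channels
    correction = MeanBiasCorrector(input_mean).calculate(weights, fake_quantized_weights)
    # One correction term per output channel: 2 groups * 8 channels = 16.
    assert correction is not None and correction.shape == (16,)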