# Source code for afe.ir.node_observer

#########################################################
# Copyright (C) 2023 SiMa Technologies, Inc.
#
# This material is SiMa proprietary and confidential.
#
# This material may not be copied or distributed without
# the express prior written permission of SiMa.
#
# All rights reserved.
#########################################################
# Code owner: Ljubomir Papuga
#########################################################
"""
Infrastructure adapting PyTorch and ONNX Runtime observers for use in ModelSDK calibration.
"""
from typing import Tuple, Optional, List, Union
import contextlib
import numpy as np
import torch
from torch.ao.quantization.observer import (
    ObserverBase, MinMaxObserver, HistogramObserver, MovingAverageMinMaxObserver, PerChannelMinMaxObserver
)
from onnxruntime.quantization.calibrate import HistogramCollector

from afe.apis.defines import (
    CalibrationMethod, HistogramEntropyMethod, HistogramPercentileMethod, MinMaxMethod,
    MovingAverageMinMaxMethod, HistogramMSEMethod
)
from afe.ir.defines import DataValue, TensorValue, Quantization, NodeReporter
from afe.ir.torch_utils import convert_numpy_to_torch, numpy_tensor_to_scalar
from sima_utils.logging.sima_logger import UserFacingException


class _ObserverWrapperBase:
    """
    Interface shared by all observer wrappers in this module.

    A concrete wrapper is fed data through ``update`` and turns the gathered
    statistics into quantization parameters through ``compute_scale_and_zp``.
    Every method of this base raises NotImplementedError.
    """

    def update(self, x: np.ndarray):
        """Record the statistics of one batch of values."""
        raise NotImplementedError("Abstract class")

    def compute_scale_and_zp(self, qrange: Optional[Tuple[int, int]] = None,
                             error_reporter: Optional[NodeReporter] = None) \
            -> Union[Tuple[float, int], Tuple[List[float], List[int]]]:
        """
        Derive the quantization scale and zero point from whatever was recorded
        by earlier ``update`` calls.

        :param qrange: Integer range to quantize into.  When None, the subclass
           chooses the range according to its own configuration.
        :param error_reporter: Optional reporter that subclasses may use for diagnostics.
        :return: A scalar (scale, zero_point) pair, or per-channel sequences of both.
        """
        raise NotImplementedError("Abstract class")

    def asymmetry(self):
        """Tell whether the wrapped observer quantizes asymmetrically."""
        raise NotImplementedError("Abstract class")

    def min_max(self):
        """Return the (min, max) statistics held by the wrapped observer."""
        raise NotImplementedError("Abstract class")


class _PyTorchObserverWrapper(_ObserverWrapperBase):
    """
    Wrapper around a per-tensor PyTorch observer (MinMax, MovingAverageMinMax or Histogram/MSE).

    The observers always work with torch.qint8; when another integer range is requested,
    the int8 parameters are converted by _pytorch_observer_range_change.
    """
    _observer: ObserverBase
    _asymmetry: bool

    def __init__(self, calibration_method: CalibrationMethod, asymmetry: bool):
        """
        :param calibration_method: One of MinMaxMethod, MovingAverageMinMaxMethod or HistogramMSEMethod.
        :param asymmetry: If True the observer uses per_tensor_affine, otherwise per_tensor_symmetric.
        :raises ValueError: If calibration_method is not supported by this wrapper.
        """
        self._asymmetry = asymmetry
        qscheme = torch.per_tensor_affine if asymmetry else torch.per_tensor_symmetric
        # UniformQuantizationObserverBase which is inherited by both MinMaxObserver
        # and HistogramObserver support only torch.qint8 or torch.quint8
        dtype = torch.qint8

        # Current implementation limits quantization range to -127, 127 in symmetric case
        qmin = -127 if not asymmetry else None
        qmax = 127 if not asymmetry else None

        if isinstance(calibration_method, MinMaxMethod):
            self._observer = MinMaxObserver(dtype=dtype, qscheme=qscheme, quant_min=qmin, quant_max=qmax)
        elif isinstance(calibration_method, MovingAverageMinMaxMethod):
            self._observer = MovingAverageMinMaxObserver(dtype=dtype, qscheme=qscheme, quant_min=qmin, quant_max=qmax)
        elif isinstance(calibration_method, HistogramMSEMethod):
            num_bins = calibration_method.num_bins
            self._observer = HistogramObserver(bins=num_bins, dtype=dtype, qscheme=qscheme,
                                               quant_min=qmin, quant_max=qmax)
        else:
            # Without this, _observer would stay unset and every later call would fail
            # with an AttributeError far away from the real cause.
            raise ValueError("Only min-max, moving average min-max and histogram MSE are valid "
                             "for this observer wrapper.")

    def update(self, x: np.ndarray):
        """Feed one batch of values to the PyTorch observer."""
        self._observer.forward(convert_numpy_to_torch(x))

    def compute_scale_and_zp(self, qrange: Optional[Tuple[int, int]] = None,
                             error_reporter: Optional[NodeReporter] = None) -> Tuple[float, int]:
        """
        Computes scale and zp using the statistics that the observer has collected.

        :param qrange: The numeric range to quantize for.  If None, the observer's own
           (quant_min, quant_max) int8 range is used.
        :param error_reporter: Optional reporter; receives a warning if the scale had to be raised
           to the minimum allowed value.
        :return: Scalar (scale, zero_point).
        """
        sc, zp = self._observer.calculate_qparams()
        # NaN/inf entries are replaced by finite numbers before any further arithmetic.
        sc = np.nan_to_num(sc.numpy())
        zp = np.nan_to_num(zp.numpy())

        # Dump statistics from the observer
        internal_debug = False
        if internal_debug and error_reporter is not None:
            range_min, range_max = self._observer.min_val, self._observer.max_val
            observer_msg = "Calibration - "
            if isinstance(self._observer, HistogramObserver):
                observer_msg += "mse: "
                clip_min, clip_max = self._observer._non_linear_param_search()
            else:
                observer_msg += "min_max: "
                clip_min, clip_max = self._observer.min_val, self._observer.max_val
            observer_msg += (f"Activation range is [{range_min}, {range_max}], "
                             f"computed clip interval is [{clip_min}, {clip_max}]")
            error_reporter.debug(observer_msg)

        # Pytorch computes scale and zero point for int8.
        # If qrange was specified, convert to scale and zero point for qrange.
        if qrange is not None:
            sc, zp = _pytorch_observer_range_change(sc, zp, qrange, self._asymmetry)

        sc = numpy_tensor_to_scalar(sc)
        zp = numpy_tensor_to_scalar(zp)

        quant_range = qrange if qrange is not None else (self._observer.quant_min, self._observer.quant_max)
        min_val, _ = self.min_max()
        sc, zp = _check_and_adjust_too_low_scales(sc, zp, quant_range, min_val, self._asymmetry, error_reporter)
        return sc, zp

    def asymmetry(self):
        """Return True if the observer was configured for asymmetric quantization."""
        return self._asymmetry

    def min_max(self) -> Tuple[float, float]:
        """Return the observer's tracked min and max as Python floats."""
        return self._observer.min_val.numpy().item(), self._observer.max_val.numpy().item()


class _OnnxObserverWrapper(_ObserverWrapperBase):
    """
    Wrapper around onnxruntime's HistogramCollector for the entropy and percentile
    calibration methods.
    """
    _observer: HistogramCollector
    _asymmetry: bool
    _qmin: int
    _qmax: int

    def __init__(self, method: CalibrationMethod, asymmetry: bool, percentile: Optional[float] = None):
        """
        :param method: HistogramEntropyMethod or HistogramPercentileMethod; its num_bins sets the
            histogram size.
        :param asymmetry: If False the collector is created with symmetric=True and the default
            quantization range is [-127, 127]; otherwise the default range is [-128, 127].
        :param percentile: Percentile forwarded to HistogramCollector.
        :raises ValueError: If method is neither entropy nor percentile.
        """
        self._asymmetry = asymmetry

        if isinstance(method, HistogramEntropyMethod):
            observer_type = "entropy"
            num_quantized_bins = 256
        elif isinstance(method, HistogramPercentileMethod):
            observer_type = "percentile"
            # Number of quantized bins is only used for entropy.
            num_quantized_bins = None
        else:
            raise ValueError("Only entropy and percentile are valid for this observer wrapper.")

        self._observer = HistogramCollector(observer_type, symmetric=not asymmetry,
                                            num_bins=method.num_bins, num_quantized_bins=num_quantized_bins,
                                            percentile=percentile)
        self._qmin = -127 if not asymmetry else -128
        self._qmax = 127

    def update(self, x: np.ndarray):
        """
        Collect values in the observer. Name in the dictionary doesn't mean anything.
        Onnxruntime HistogramCollector collect function takes the tensor values as a dictionary,
        which is why a made-up name is used.
        """
        # This will block print statements from inside onnxruntime.
        with contextlib.redirect_stdout(None):
            self._observer.collect({"input1": x})

    def _first_histogram(self):
        """Return the histogram entry of the single tensor collected through update."""
        # Entry layout: [hist, hist_edges, min, max, ...]
        return next(iter(self._observer.get_histogram_dict().values()))

    def _raise_unusual_distribution(self, min_val: float, max_val: float) -> None:
        """
        Raise UserFacingException when calibration returned an empty clip interval for a tensor
        whose observed data range is not all zero.  All-zero tensors return without raising.
        """
        calibration_method = self._observer.method
        if calibration_method == 'percentile':
            calibration_method += f" ({self._observer.percentile}%)"
        hist = self._first_histogram()
        data_range = [hist[2], hist[3]]
        if data_range[0] == 0 and data_range[1] == 0:
            # Escape zero tensors: a degenerate interval is expected for them.
            return
        non_zero_bins = np.nonzero(hist[0])
        non_zero_counts = hist[0][non_zero_bins]
        total_count = sum(non_zero_counts)
        # Percentage of all elements that fall into the first non-empty bin.
        th = non_zero_counts[0] * 100 / total_count
        calibration_error = (f"Unusual distribution detected for tensor with {total_count} elements!\n"
                             f"Histogram non-zero bins = {non_zero_bins}\ndistribution = {non_zero_counts}\n"
                             f"Invalid clip interval [{min_val}, {max_val}] returned by calibration method "
                             f"{calibration_method} for data range {data_range}.\n"
                             f"The first non-zero bin is already {th}%\n"
                             "Please try different calibration methods "
                             f"or increase percentile value over {th} for percentile method.")
        raise UserFacingException(calibration_error)

    def compute_scale_and_zp(self, qrange: Optional[Tuple[int, int]] = None,
                             error_reporter: Optional[NodeReporter] = None) -> Tuple[float, int]:
        """
        Computes scale and zp from the clip interval produced by the onnxruntime collector.

        Note for onnxruntime 1.15.0
        The function compute_collection_result() may return one of two possible types: np.float or np.float32.
        When SiMa Model is saved, yaml writer only takes native float type which is np.float. Hence, min_max()
        forces a cast of the min and max values returned from onnx runtime.

        :param qrange: The numeric range to quantize for.  If None, the range is determined
           based on how this class instance was initialized.
        :param error_reporter: Optional reporter; receives a warning if the scale had to be raised
           to the minimum allowed value.
        :return: Scalar (scale, zero_point).
        :raises UserFacingException: If a non-zero tensor yields an empty clip interval.
        """
        quant_range = qrange if qrange is not None else (self._qmin, self._qmax)
        # Symmetry follows the target range: only ranges of the form [-q, q] are treated as symmetric.
        asymmetry = quant_range[0] != -quant_range[1]
        min_val, max_val = self.min_max()

        # Dump statistics from the observer
        internal_debug = False
        if internal_debug and error_reporter is not None:
            hist = self._first_histogram()
            range_min, range_max = hist[2], hist[3]
            observer_msg = f"Calibration - {self._observer.method}: "
            observer_msg += (f"Activation range is [{range_min}, {range_max}], "
                             f"computed clip interval is [{min_val}, {max_val}]")
            error_reporter.debug(observer_msg)

        # Adjust min and max such that 0 is included in the range. This is
        # required to make sure zero can be represented by the quantization data
        # type (i.e. to make sure qmin <= zero_point <= qmax).
        # It also helps quantization of a scalar constant that is not folded away.
        min_val = min(min_val, 0)
        max_val = max(max_val, 0)

        # Detect unusual ranges of non-zero tensors that will produce zero scales for quantization
        if min_val >= max_val:
            self._raise_unusual_distribution(min_val, max_val)

        if not asymmetry:
            absmax = max(abs(min_val), abs(max_val))
            min_val = -absmax
            max_val = +absmax

        sc = (max_val - min_val) / float(quant_range[1] - quant_range[0])
        zp = round(quant_range[0] - min_val / sc) if sc != 0 else 0

        # In case scale is too low we will use a small value and recalculate zp
        sc, zp = _check_and_adjust_too_low_scales(sc, zp, quant_range, min_val, asymmetry, error_reporter)

        return sc, zp

    def asymmetry(self):
        """Return True if the wrapper was configured for asymmetric quantization."""
        return self._asymmetry

    def min_max(self) -> Tuple[float, float]:
        """
        Get the clip interval computed by the collector, cast to Python floats.
        """
        # This line disables print statements from inside onnxruntime.
        with contextlib.redirect_stdout(None):
            collection_result = self._observer.compute_collection_result()
            thresholds = next(iter(collection_result.values()))
            # (min, max) are the first two entries.  Indexing instead of unpacking keeps this
            # working if an onnxruntime version appends further items (e.g. histogram data)
            # to the tuple — NOTE(review): confirm against the pinned onnxruntime version.
            min_val = float(thresholds[0])
            max_val = float(thresholds[1])
        return min_val, max_val


class _MeanObserver(_ObserverWrapperBase):
    """
    Observer set up to track mean values.

    Instead of quantization parameters it keeps an exponential moving average of the
    per-sample, per-channel mean of its inputs (the last axis is the channel axis).
    The quantization-related methods are not supported.
    """

    # Running average; None until the first update.
    _mean: Optional[np.ndarray] = None
    # Weight given to the previous running average on each update.
    _alpha: float = 0.99

    def compute_scale_and_zp(self, qrange: Optional[Tuple[int, int]] = None,
                             error_reporter: Optional[NodeReporter] = None) -> Tuple[float, int]:
        raise RuntimeError("Instance of MeanObserver shouldn't run this.")

    def asymmetry(self):
        raise RuntimeError("Instance of MeanObserver shouldn't run this.")

    def min_max(self):
        raise RuntimeError("Instance of MeanObserver shouldn't run this.")

    def get_mean(self):
        """Return the running mean, or None if update was never called."""
        return self._mean

    def update(self, x: np.ndarray):
        """
        Fold the mean of x into the running average.

        x is viewed as (n, spatial, c) with n = x.shape[0], c = x.shape[-1] and spatial the product
        of all axes in between (1 when there are none).  It is averaged over the spatial axis,
        giving an (n, c) array.
        """
        n = x.shape[0]
        c = x.shape[-1]
        # np.prod of an empty shape is the float 1.0, which reshape rejects; cast to int.
        # (np.product, used previously, was removed in NumPy 2.0.)
        spatial = int(np.prod(x.shape[1:-1]))
        mu = np.mean(np.reshape(x, [n, spatial, c]), 1)
        if self._mean is None:
            # NOTE(review): the first update stores (1 - alpha) * mu rather than mu, so the estimate
            # starts biased toward zero — confirm this is intended.
            self._mean = (1 - self._alpha) * mu
        else:
            self._mean = self._alpha * self._mean + (1 - self._alpha) * mu


class _MinMaxPerChannelObserver(_ObserverWrapperBase):
    """
    Observer used for per channel calibration of values.
    It shouldn't be used for complete model calibration like other observers.

    Wraps torch's PerChannelMinMaxObserver with torch.qint8; scales and zero points are
    returned as arrays with one entry per channel.
    """
    _observer: PerChannelMinMaxObserver
    _asymmetry: bool

    def __init__(self, asymmetry: bool, channel_axis: int):
        """
        :param asymmetry: If True use per_channel_affine with PyTorch's default qint8 range,
            otherwise per_channel_symmetric restricted to [-127, 127].
        :param channel_axis: Axis of the input that holds the channels.
        """
        self._asymmetry = asymmetry
        # Current implementation limits quantization range to -127, 127 in symmetric case
        qmin = -127 if not asymmetry else None
        qmax = 127 if not asymmetry else None

        qscheme = torch.per_channel_affine if self._asymmetry else torch.per_channel_symmetric
        self._observer = PerChannelMinMaxObserver(ch_axis=channel_axis, dtype=torch.qint8, qscheme=qscheme,
                                                  quant_min=qmin, quant_max=qmax)

    def update(self, x: np.ndarray):
        """Feed one batch of values to the per-channel observer."""
        self._observer.forward(torch.from_numpy(x))

    def compute_scale_and_zp(self, qrange: Optional[Tuple[int, int]] = None,
                             error_reporter: Optional[NodeReporter] = None) -> Tuple[List[float], List[int]]:
        """
        Compute per-channel scales and zero points (returned as numpy arrays).

        :param qrange: The numeric range to quantize for.  If None, the int8 parameters computed
           by PyTorch are returned unchanged.
        :param error_reporter: Optional reporter used only for internal debug output.
        """
        sc, zp = self._observer.calculate_qparams()
        sc, zp = sc.numpy(), zp.numpy()

        # Dump statistics from the observer
        internal_debug = False
        if internal_debug and error_reporter is not None:
            observer_msg = "Calibration - min_max per channel: "
            per_channel_range = np.dstack((self._observer.min_val.numpy(), self._observer.max_val.numpy()))
            observer_msg += f"Activation min and max for each channel is {per_channel_range}"
            error_reporter.debug(observer_msg)

        # Pytorch computes scale and zero point for int8.
        # If qrange was specified, convert to scale and zero point for qrange.
        if qrange is not None:
            sc, zp = _pytorch_observer_range_change(sc, zp, qrange, self._asymmetry)

        return sc, zp

    def asymmetry(self) -> bool:
        """Return True if the observer was configured for asymmetric quantization."""
        return self._asymmetry

    def min_max(self) -> Tuple[np.ndarray, np.ndarray]:
        """Return the per-channel min and max arrays tracked by the observer."""
        return self._observer.min_val.numpy(), self._observer.max_val.numpy()


class _QDQObserver(_ObserverWrapperBase):
    """
    Observer set up to keep information about zero_point and scale, extracted from QDQ layers.

    No statistics are collected: compute_scale_and_zp returns the stored values.
    """
    def __init__(self, zero_point: np.ndarray, scale: np.ndarray):
        """
        :param zero_point: Non-empty array of integer zero points.
        :param scale: Non-empty array of floating-point scales with as many elements as zero_point.
        """
        assert zero_point.size > 0 and scale.size > 0, "Error: zero_point and scale must be a non-empty arrays"
        assert zero_point.size == scale.size, "Error: zero_point and scale must have same number of elements."
        # np.issubdtype accepts every integer / floating width.  The previous np.issubsctype check
        # (removed in NumPy 2.0) only accepted subclasses of np.int_ / np.float64, so e.g. int32
        # zero points or float32 scales were rejected.
        assert np.issubdtype(zero_point.dtype, np.integer), "Error: elements of zero_point must be of type int"
        assert np.issubdtype(scale.dtype, np.floating), "Error: elements of scale must be of type float"

        self.zero_point = zero_point.flatten()
        self.scale = scale.flatten()

    def compute_scale_and_zp(self, qrange: Optional[Tuple[int, int]] = None,
                             error_reporter: Optional[NodeReporter] = None) \
            -> Union[Tuple[float, int], Tuple[np.ndarray, np.ndarray]]:
        """
        Return the reciprocal of the stored scale and the stored zero point.

        Single-element values are returned as Python float / int; otherwise arrays are returned.
        qrange and error_reporter are ignored.
        """
        zp = self.zero_point if self.zero_point.size != 1 else int(self.zero_point.item())
        sc = self.scale if self.scale.size != 1 else float(self.scale.item())

        # Node observer expects scale in ONNX/PyTorch representation
        return 1.0 / sc, zp

    def asymmetry(self):
        raise RuntimeError("Instance of QDQObserver shouldn't run this.")

    def min_max(self):
        raise RuntimeError("Instance of QDQObserver shouldn't run this.")

    def update(self, x: np.ndarray):
        # For QDQObserver, update should do nothing, since sc and zp are extracted from model's QDQ layers
        pass


def _pytorch_observer_range_change(sc: np.ndarray, zp: np.ndarray, qrange: Tuple[int, int], asymmetry: bool):
    """
    Re-express int8 scale/zero point computed by a PyTorch observer for another integer range.

    The floating-point interval covered by the int8 parameters is reconstructed and then
    quantized again over ``qrange``.

    :param sc: 1-D array of scales computed by PyTorch for int8.
    :param zp: 1-D array of zero points matching ``sc``.
    :param qrange: Target (qmin, qmax) integer range.
    :param asymmetry: True if PyTorch used [-128, 127], False if it used the symmetric [-127, 127].
    :return: (scale, zero_point) 1-D arrays for ``qrange``; zero points are int32.
    """
    assert qrange is not None
    assert isinstance(sc, np.ndarray) and sc.ndim == 1
    assert isinstance(zp, np.ndarray) and zp.ndim == 1
    # Floating-point interval represented by PyTorch's int8 parameters.
    i8min = -128 if asymmetry else -127
    i8max = 127
    rmin = sc * (i8min - zp)
    rmax = sc * (i8max - zp)

    # Quantize the same interval over the requested range.
    qmin, qmax = qrange
    new_sc = (rmax - rmin) / (qmax - qmin)
    new_zp = qmin - (rmin / new_sc)
    # Python's built-in round() also rounds halves to even (same as np.round), so this does NOT
    # avoid banker's rounding.  Rounding element-wise with round() does make NaN/inf zero points
    # raise instead of being silently cast to int32.
    new_zp = np.array([round(value) for value in new_zp]).astype(np.int32)
    return new_sc, new_zp


def _check_and_adjust_too_low_scales(sc: float, zp: int, quant_range: Tuple[int, int], min_val: float, asymmetry: bool,
                                     error_reporter: Optional[NodeReporter] = None) -> Tuple[float, int]:
    """
    Enforce a lower bound (float32 machine epsilon) on the quantization scale.

    If ``sc`` is below the bound, the scale becomes the bound and the zero point is recomputed:
    from ``min_val`` and clipped into ``quant_range`` for asymmetric quantization, 0 for symmetric.
    A warning is sent to ``error_reporter`` when one is given.

    :return: The (possibly adjusted) scale and zero point.
    """
    min_scale = torch.finfo(torch.float32).eps
    # Written as "not (<)" so NaN scales pass through unchanged.
    if not (sc < min_scale):
        return sc, zp

    new_zp = 0
    if asymmetry:
        lo, hi = quant_range[0], quant_range[1]
        new_zp = int(np.clip(round(lo - min_val / min_scale), lo, hi))

    if error_reporter is not None:
        error_reporter.warn(f"Calculated scale is too low. It will be adjusted. Values for scale and zeropoint "
                            f"{sc}, {zp} are rewritten to {min_scale}, {new_zp}")
    return min_scale, new_zp


class NodeObserver:
    """
    A module used for observing the statistics of the node's output data and calculation of
    quantization parameters. It uses the PyTorch and Onnx observer infrastructures to collect
    the statistics and calculate quantization parameters.

    Variable _observer is the instance of the PyTorch, Onnx or special Mean observer. Onnx and
    Pytorch are used for calibration and calculating scales and zero points, while Mean observer
    tracks mean statistics. Currently, MinMaxObserver, MovingAverageMinMaxObserver and HistogramMSE
    are supported with PyTorch. Onnxruntime is used for HistogramEntropy and HistogramPercentile
    observers. The observer flavor is determined by the calibration_method constructor argument.
    The symmetric/asymmetric quantization scheme is selected using the asymmetry constructor argument.

    :param calibration_method: Flavor of observer used in calibration.
    :param asymmetry: Whether the observer quantizes asymmetrically.
    :param node_reporter: Node reporter for the node being observed.
    :param do_mean_estimation: Boolean determining whether we do mean estimation. If True, calibration
        methods are ignored and MeanObserver is used, which will simply calculate the mean.
    :param per_channel_min_max: If True (and neither qdq_quantization nor do_mean_estimation applies),
        a per-channel min/max observer over the last axis is used.
    :param qdq_quantization: Quantization extracted from QDQ layers. When truthy it takes precedence
        over every other option and no statistics are collected.
    :raises ValueError: If no option selects an observer and calibration_method is not supported.
    """
    _observer: _ObserverWrapperBase
    _calibration_method: Optional[CalibrationMethod]
    _node_reporter: Optional[NodeReporter] = None

    def __init__(self, *, calibration_method: CalibrationMethod = MinMaxMethod(), asymmetry: bool = True,
                 node_reporter: Optional[NodeReporter] = None, do_mean_estimation: bool = False,
                 per_channel_min_max: bool = False, qdq_quantization: Optional[Quantization] = None):
        self._calibration_method = calibration_method
        if qdq_quantization:
            self._observer = _QDQObserver(np.array(qdq_quantization.zero_point), np.array(qdq_quantization.scale))
            self._calibration_method = None
        elif do_mean_estimation:
            # calibration methods are ignored and MeanObserver is used, which will simply calculate the mean.
            # Unlike observers which calculate quantization parameters like scale and zero point, this just does mean.
            self._observer = _MeanObserver()
        elif per_channel_min_max:
            # Initialize per channel min max observer. It is not possible to use this as a general calibration method.
            # It is used in special cases as an intermediate observer.
            self._observer = _MinMaxPerChannelObserver(asymmetry, channel_axis=-1)
        elif isinstance(calibration_method, (MinMaxMethod, HistogramMSEMethod, MovingAverageMinMaxMethod)):
            self._observer = _PyTorchObserverWrapper(calibration_method, asymmetry)
        elif isinstance(calibration_method, (HistogramEntropyMethod, HistogramPercentileMethod)):
            percentile_value = calibration_method.percentile_value \
                if isinstance(calibration_method, HistogramPercentileMethod) else None
            self._observer = _OnnxObserverWrapper(calibration_method, asymmetry, percentile=percentile_value)
        else:
            # Previously _observer was silently left unset, so every later call failed with AttributeError.
            raise ValueError(f"Unsupported calibration method: {calibration_method}")
        self._node_reporter = node_reporter

    def update(self, x: np.ndarray):
        """
        Updates the statistic of the node's output data.

        :param x: np.ndarray. The node's output data.
        """
        self._observer.update(x)

    def calculate_quantization(self, qrange: Optional[Tuple[int, int]] = None) -> DataValue[Quantization]:
        """
        Calculates the output Quantization parameters of the node. It calls the
        ObserverWrapperBase.compute_scale_and_zp method. When constructing Quantization instance,
        scale is modified to follow SiMa quantization scheme (reciprocal value is used).

        :param qrange: The numeric range to quantize for. If None, the range is determined based on
            how the observer was initialized.
        :return: DataValue[Quantization]. Quantization parameters of the node's output wrapped in the
            DataValue instance. Since currently there is no support for nodes producing multiple
            outputs, TensorValue is always used.
        """
        from afe.ir.quantization_utils import significant_bits_signed  # Local import because of circular dependence
        sc, zp = self._observer.compute_scale_and_zp(qrange, self._node_reporter)
        if qrange is None:
            bits = 8  # The default quantization is for 8 bits.
        else:
            bits = max(significant_bits_signed(qrange[0]), significant_bits_signed(qrange[1]))
            bits = ((bits + 7) // 8) * 8  # Round up to multiple of 8
            if bits == 24:
                bits = 32  # Round up to a natively supported int width
        assert bits <= 32, "Quantization with larger than 32 bits is not supported"
        quant = Quantization.representable(self._reciprocal_scale(sc), zp, bits)
        return TensorValue(quant)

    @staticmethod
    def _reciprocal_scale(sc):
        """
        Convert a scale to SiMa's reciprocal representation, mapping a zero scale to 0.

        Scalars and single-element arrays follow the scalar path; larger arrays (per-channel scales,
        as returned by the QDQ and per-channel observers) are inverted element-wise, because
        ``if sc == 0`` on such an array raises "truth value is ambiguous".
        """
        if np.size(sc) == 1:
            return 0. if sc == 0 else 1.0 / sc
        # NOTE(review): assumes Quantization.representable accepts an array of scales for
        # per-channel quantization — confirm.
        scales = np.asarray(sc)
        inverted = np.zeros(scales.shape, dtype=np.result_type(scales, 1.0))
        np.divide(1.0, scales, out=inverted, where=scales != 0)
        return inverted

    def observer_type(self) -> Optional[CalibrationMethod]:
        """Return the calibration method (None for QDQ-derived quantization)."""
        return self._calibration_method

    def asymmetry(self) -> bool:
        """Return whether the underlying observer quantizes asymmetrically."""
        return self._observer.asymmetry()

    def min_max(self) -> Tuple[float, float]:
        """Return the min/max statistics of the underlying observer."""
        return self._observer.min_max()

    def get_mean(self):
        """Return the running mean tracked by a mean-estimation observer."""
        assert isinstance(self._observer, _MeanObserver), "This method should only be called for _MeanObserver"
        return self._observer.get_mean()