Source code for afe.ir.quantization_utils

#########################################################
# Copyright (C) 2020 SiMa Technologies, Inc.
#
# This material is SiMa proprietary and confidential.
#
# This material may not be copied or distributed without
# the express prior written permission of SiMa.
#
# All rights reserved.
#########################################################
# Code owner: Joey Chou
#########################################################
import dataclasses
import math

import numpy as np
from enum import Enum
from typing import List, Union, Tuple, Optional, Callable, cast, Dict, TypeVar, Any

from sima_utils.logging import sima_logger

import afe.ir.attributes as attributes
from afe.ir.defines import (
    QuantizedTensor, QuantizedParam, NodeName, Float, InputName, DataValue, Quantization,
    RequantMethod, TensorValue, TupleValue, QuantizedTensorNew, QuantizedTensorInt16, QuantizationCast, IdentityCast,
    QuantCast, DequantCast, RequantCast, ConvertCast
)
from afe.ir.tensor_type import ScalarType, scalar_is_integral
from afe.ir.utils import (
    transpose_axis_to_the_last, create_and_verify_narrowing
)
from afe.core.configs import QuantizationConfigs
from ml_kernels.math_helpers import RoundType
from ml_kernels.np_operators import ideal_udf
from ml_kernels.requantization import Narrowing, BaseRequantization, FractionalZeroRequantization, \
    TFLiteRequantization, ArithFoldedRequantization, narrow, narrowing_requantization
import ml_kernels.requantization
from ml_kernels.udf import ComputeLookupTable

_TENSOR = TypeVar("_TENSOR")

_ZERO_KERNEL_THRESHOLD = 0.001


class QNNDtype(str, Enum):
    """Data types used in QNN operations"""
    INT8 = 'int8'
    UINT8 = 'uint8'
    INT32 = 'int32'


DTYPE_BOUNDS = {QNNDtype.INT8: (-128, 127),
                QNNDtype.UINT8: (0, 255),
                QNNDtype.INT32: (-2147483648, 2147483647)}

def _override_zero_scale(scale: Union[float, np.ndarray, List]) -> Union[float, np.ndarray, List]:
    """
    If a scale is equal to 0, set it to 1. Some parts of the code do not behave
    properly if a scale is 0, so in that case we set the scale to 1.
    """
    if isinstance(scale, np.ndarray):
        return np.where(scale == 0, 1.0, scale)
    elif isinstance(scale, List):
        return [sc if sc != 0 else 1.0 for sc in scale]
    else:
        assert isinstance(scale, float)
        return scale if scale != 0 else 1.0

def round_op(x: float, rounding_type: RoundType = RoundType.TOEVEN) -> float:
    """
    Round a value to the nearest integer using the given rounding mode.

    :param x: A float32 number to be rounded
    :param rounding_type: Rounding mode to use
    :return: Rounded result
    """
    if rounding_type == RoundType.UPWARD:
        return np.floor(x + 0.5)  # round half toward +inf
    elif rounding_type == RoundType.TOEVEN:
        return np.round(x)  # round half to even
    elif rounding_type == RoundType.TONEAREST:
        return np.sign(x) * np.floor(np.abs(x) + 0.5)  # round half away from 0
    elif rounding_type == RoundType.TRUNC:
        return np.trunc(x)  # round toward 0
    else:
        raise ValueError(f"Expected {[t.name for t in RoundType]} for the rounding type, "
                         f"got {rounding_type}")

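# Illustrative examples (not part of the original module): the four rounding
# modes differ only in how they treat halfway values.
#
#   round_op(2.5, RoundType.UPWARD)     -> 3.0   # half toward +inf
#   round_op(2.5, RoundType.TOEVEN)     -> 2.0   # half to even
#   round_op(-2.5, RoundType.TONEAREST) -> -3.0  # half away from zero
#   round_op(-2.5, RoundType.TRUNC)     -> -2.0  # toward zero
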
def calculate_normalization_shift(scale: Union[float, np.ndarray],
                                  rounding: RoundType = RoundType.TRUNC) -> Union[int, np.ndarray]:
    """
    Calculate the number of shifts needed to normalize a scale. The original scale
    will be normalized, depending on the rounding type, after dividing by (2**shift).
    """
    from afe.ir.quantization_conv import decompose_power_of_2
    if isinstance(scale, np.ndarray):
        exp, _ = decompose_power_of_2(scale, rounding=rounding)
        shift = exp.astype(int)
    else:
        exp, _ = decompose_power_of_2(np.asarray(scale), rounding=rounding)
        shift = int(exp)
    return shift

def get_bound(bits: int, signed: bool = True) -> int:
    if signed:
        bound = np.power(2., bits - 1)
    else:
        bound = np.power(2., bits)
    return bound.astype(QuantizedParam)

def clip_to_targeted_range(x: Union[int, np.ndarray], bits: int,
                           restricted_range: bool = False) -> Union[int, np.ndarray]:
    """
    Clip x to the target range determined by the given bit number.

    :param x: Numpy array or int
    :param bits: Number of bits used to determine the min and max number
    :param restricted_range: If true, abs(a_min) == abs(a_max)
    """
    bound = get_bound(bits)
    a_min = -bound
    a_max = bound - 1
    if restricted_range:
        a_min += 1
    return np.clip(x, a_min, a_max)

def compute_scale(asymmetry: bool, layer_bits: int, min_val: float, max_val: float,
                  include_real_zero_point: bool = False,
                  ) -> float:
    """
    Compute a linear quantization scale for mapping the range (min_val, max_val)
    onto the quantized integer range determined by layer_bits,
    include_real_zero_point, and asymmetry.

    The computed scale is the reciprocal of the scale in TFLite's convention.

    :param asymmetry: If true, do asymmetric quantization.
    :param layer_bits: Number of bits used for quantization.
    :param min_val: Minimum value.
    :param max_val: Maximum value.
    :param include_real_zero_point: If True, force the float dynamic range to cover zero.
    :return: Computed scale s such that real numbers r are converted to integers q
        by the formula q = round(s * r).
    """
    if include_real_zero_point:
        min_val = min(0, min_val)
        max_val = max(0, max_val)

    if asymmetry:
        # -128 ~ 127
        # Measure the size of the range [MIN, MAX]
        max_integer = 2 ** layer_bits - 1
        max_float = max_val - min_val
    else:
        # Measure the larger of the ranges [MIN, 0] and [0, MAX]
        max_integer = 2 ** (layer_bits - 1) - 1
        max_float = max(abs(min_val), abs(max_val))

    # Workaround for division by 0. In this case the range is a single point
    # and there is no meaningful scale factor.
    if max_val == min_val:
        max_float = max_val

    # 1/S = max_integer / max_float
    if max_float == 0:
        # The range is a single point and there is no meaningful
        # scale factor. Use special value 0 to signal this.
        return 0.
    return float(abs(max_integer / max_float))

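# Worked example (illustrative, not part of the original module). With asymmetric
# 8-bit quantization, the 255 integer steps span the full float range; with
# symmetric quantization, 127 steps span the larger half-range:
#
#   compute_scale(asymmetry=True, layer_bits=8, min_val=-1.0, max_val=1.0)   # -> 127.5
#   compute_scale(asymmetry=False, layer_bits=8, min_val=-1.0, max_val=0.5)  # -> 127.0
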
def compute_zero_point(asymmetry: bool, layer_bits: int, min_val: float, max_val: float,
                       restricted_range: bool = False
                       ) -> int:
    """
    Given min and max values, compute the zero point.

    :param asymmetry: If true, do asymmetric quantization.
    :param layer_bits: Number of bits used for quantization.
    :param min_val: Minimum value.
    :param max_val: Maximum value.
    :param restricted_range: If True, the dynamic range will be equal on the
        negative and positive sides.
    :return: Zero point.
    """
    if not asymmetry:
        return 0
    if min_val == 0 and max_val != 0:
        # Somehow, when min_val == 0 and the zero point is -(2^(layer_bits - 1) - 1),
        # accuracy decreases a lot. So here we make sure the zero point is
        # -(2^(layer_bits - 1)) in that case.
        return int(-np.power(2, layer_bits - 1))
    if min_val == 0 and max_val == 0:
        # If all values are zero, then zp is also 0.
        return 0
    fp_zp = -0.5 * (min_val + max_val) * compute_scale(
        asymmetry, layer_bits, min_val, max_val, include_real_zero_point=True)
    if not restricted_range:
        # Bias toward the negative side by 1
        fp_zp -= (1. / 2)
    zp = int(round_op(fp_zp))
    return int(clip_to_targeted_range(zp, layer_bits, restricted_range))

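# Worked example (illustrative, not part of the original module). For the range
# [-0.5, 1.5] with 8 bits, the scale is 255 / 2.0 = 127.5, so the range midpoint
# 0.5 maps to -63.75; after the negative bias of 0.5 and round-to-even:
#
#   compute_zero_point(asymmetry=True, layer_bits=8, min_val=-0.5, max_val=1.5)  # -> -64
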
def significant_bits_signed(n: int) -> int:
    """
    Get the smallest signed integer bit width that can represent the given integer.

    > significant_bits_signed(-129) = 9
    > significant_bits_signed(-128) = 8
    > significant_bits_signed(127) = 8
    > significant_bits_signed(128) = 9
    """
    # Calculate bit-width in sign-magnitude representation
    bits = 1 + n.bit_length()
    # Convert to bit-width in two's complement representation
    if n < 0 and (-n & (-1 - n)) == 0:
        # n is a negative power of 2. Its size is 1 bit less than in the
        # sign-magnitude representation.
        bits -= 1
    return bits

def compute_power_of_2_scale_and_shift(scale: Union[float, np.ndarray],
                                       input_bit: int,
                                       output_bit: int
                                       ) -> Union[Tuple[int, int], Tuple[np.ndarray, np.ndarray]]:
    """
    Given a float scale or a vector of scales, and the quantized bit numbers for
    input and output, return a quantized scale and a right shift.

    :param scale: Union[float, np.ndarray]
    :param input_bit: int. Number of bits used for input quantization
    :param output_bit: int. Number of bits used for output quantization
    :return: Union[Tuple[int, int], Tuple[np.ndarray, np.ndarray]]. Tuple of
        (scale, right shift)
    """
    right_shift = -calculate_normalization_shift(scale)
    quantized_scale = (scale * np.power(2., right_shift)).astype(np.float32)

    # Avoid overflow
    input_bound = get_bound(input_bit)
    output_bound = get_bound(output_bit)
    if isinstance(quantized_scale, np.ndarray):
        assert quantized_scale.ndim == 1
        assert scale.shape == quantized_scale.shape
        for i in range(len(quantized_scale)):
            if round_op(quantized_scale[i] * (input_bound // 2)) == output_bound:
                right_shift[i] -= 1
                quantized_scale[i] /= 2.
    else:
        if round_op(quantized_scale * (input_bound // 2)) == output_bound:
            right_shift -= 1
            quantized_scale /= 2.

    right_shift += (input_bit - 1)
    quantized_scale *= input_bound
    return quantized_scale, right_shift

def compute_weight_scale(weight: np.ndarray, bits: int) -> float:
    """
    Compute the weight scale. Weights are always quantized symmetrically.

    :param weight: Weight tensor.
    :param bits: Number of bits used to quantize the weight.
    :return: Scale of the weight.
    """
    return compute_scale(asymmetry=False, layer_bits=bits,
                         min_val=np.min(weight), max_val=np.max(weight),
                         include_real_zero_point=True)

def compute_weight_scale_per_channel(weight: np.ndarray, bits: int) -> np.ndarray:
    """
    Compute per-channel weight scales. The expected layout of the weight is
    AwesomeConvWeightLayout.

    :param weight: Weight tensor in AwesomeConvWeightLayout format.
    :param bits: Number of bits used to quantize the weight.
    :return: An array of weight scales.
    """
    scales = np.zeros((weight.shape[-1])).astype(float)
    for i in range(weight.shape[-1]):
        scales[i] = compute_weight_scale(weight[..., i], bits)
    return scales

def linear_scale(input: np.ndarray, scale: float, bits: int, clip: bool = True) -> np.ndarray:
    """
    Linearly scale the input based on the scale, and clip the scaled input based
    on the bit number.

    :param input: A numpy array.
    :param scale: A scale factor used to scale the input to a target range.
    :param bits: Number of bits used to clip the scaled input.
    :param clip: If true, clip the linearly scaled result to the given dynamic range.
    :return: Scaled input.
    """
    res = input * scale
    if clip:
        res = round_op(res)
        res = clip_to_targeted_range(res, bits)
    return res.astype(QuantizedTensor)

def linear_scale_per_channel(input: np.ndarray, scale: np.ndarray, bits: int,
                             clip: bool = True) -> np.ndarray:
    """
    Linearly scale the input based on per-channel scales, and clip the scaled input
    based on the bit number. The output channel has to be the last dimension.

    :param input: A numpy array.
    :param scale: A numpy array of scale factors used to scale the input to
        different target ranges in different channels.
    :param bits: Number of bits used to clip the scaled input.
    :param clip: If true, clip the linearly scaled results to the given dynamic range.
    :return: Scaled input.
    """
    res = np.zeros(input.shape, dtype=QuantizedTensor)
    for i in range(input.shape[-1]):
        res[..., i] = linear_scale(input[..., i], scale[i], bits, clip)
    return res

def linear_quantize(input: np.ndarray, scale: float, zp: int, bits: int) -> np.ndarray:
    """
    quantized_input = (input / S) + zero_point.

    :param input: A numpy array.
    :param scale: scale = (1/S) in the above equation.
    :param zp: Zero point of the quantized input.
    :param bits: Number of bits used to clip the scaled input.
    :return: Quantized input.
    """
    scale = _override_zero_scale(scale)
    delta = -scale if scale < 0 else 1 / scale
    rounded = round_op(input / delta) + zp
    res = clip_to_targeted_range(rounded, bits)
    out_type = QuantizedTensorInt16 if bits == 16 else QuantizedTensorNew
    return np.array(res.astype(out_type))

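# Illustrative example (not part of the original module), quantizing to 8 bits
# with scale = 127 (i.e. S = 1/127) and zero point 0:
#
#   linear_quantize(np.array([-1.0, 0.004, 1.2]), scale=127.0, zp=0, bits=8)
#   # -> [-127, 1, 127]; 1.2 * 127 = 152.4 saturates at the int8 maximum
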
def linear_quantize_with_quantization(input: np.ndarray, quantization: Quantization) -> np.ndarray:
    """
    Apply a quantization to a floating-point tensor to produce a quantized tensor.

    :param input: Floating-point tensor
    :param quantization: Quantization to apply
    :return: Quantized tensor
    """
    return linear_quantize(input, quantization.scale, quantization.zero_point, quantization.bits)

def quantize_value(value: Any, q: DataValue[Optional[Quantization]]) -> Any:
    """
    Quantize a value according to the given quantization. Values consist of arrays
    and tuples.

    :param value: Value to quantize. It must consist of numpy arrays and tuples.
    :param q: Quantization of the value. None means that the value is not quantized
        and so it will be returned unchanged.
    :return: Quantized value. It has the same tuple structure as the input.
    """
    if isinstance(q, TensorValue):
        assert isinstance(value, np.ndarray)
        if q.value is None:
            # Not quantized
            return value
        else:
            # Quantized
            return linear_quantize_with_quantization(value, q.value)
    elif isinstance(q, TupleValue):
        assert isinstance(value, (list, tuple))
        assert len(value) == len(q.elements)
        return tuple(quantize_value(v, p) for v, p in zip(value, q.elements))

def dequantize_value(value: Any, q: DataValue[Optional[Quantization]]) -> Any:
    """
    Dequantize a value according to the given quantization. Values consist of arrays
    and tuples.

    :param value: Value to dequantize. It must consist of numpy arrays and tuples.
    :param q: Quantization of the value. None means that the value is not quantized
        and so it will be returned unchanged.
    :return: Dequantized value. It has the same tuple structure as the input.
    """
    if isinstance(q, TensorValue):
        assert isinstance(value, np.ndarray)
        if q.value is None:
            # Not quantized
            return value
        else:
            # Quantized
            return dequantize(value, 1 / q.value.scale, q.value.zero_point)
    elif isinstance(q, TupleValue):
        assert isinstance(value, (list, tuple))
        assert len(value) == len(q.elements)
        return tuple(dequantize_value(v, p) for v, p in zip(value, q.elements))

def get_zero_kernel_mask_per_channel(weight: np.ndarray, threshold: float) -> np.ndarray:
    """
    Return a mask of zero kernels. The kernel layout of the weight must be
    AwesomeConvWeightLayout.

    :param weight: Weights for convolution in AwesomeConvWeightLayout layout.
    :param threshold: If the sum of a kernel's absolute values is smaller than the
        threshold, the kernel is treated as a zero kernel.
    :return: Mask of zero kernels. True means the kernel is a zero kernel.
    """
    mask = []
    for i in range(weight.shape[-1]):
        mask.append(np.sum(abs(weight[..., i])) <= threshold)
    return np.array(mask)

def dequantize(input: np.ndarray, scale: float, zp: int) -> np.ndarray:
    """
    Original equation: quantized_input = (input / S) + zero_point.
    Reverse it to get the dequantization equation:
    dequantized_input = (quantized_input - zero_point) * S

    :param input: A numpy array.
    :param scale: scale = S in the above equation (the reciprocal of the
        quantization scale).
    :param zp: Zero point of the quantized input.
    :return: Dequantized input.
    """
    return ((input.astype(np.int32) - zp) * scale).astype(Float)

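# Round-trip sketch (illustrative, not part of the original module): quantizing
# with 1/S = 127 and dequantizing with S = 1/127 recovers the input up to
# rounding error:
#
#   q = linear_quantize(np.array([0.5]), scale=127.0, zp=0, bits=8)  # -> [64]
#   dequantize(q, scale=1.0 / 127.0, zp=0)                           # -> [~0.504]
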
def _requantize_half_up(input: np.ndarray, bits: int, right_shift: int,
                        zp: Optional[int] = None) -> np.ndarray:
    """
    Requantize a quantized tensor to another quantization domain.

    :param input: A numpy array.
    :param bits: Number of bits used to clip the scaled input.
    :param right_shift: Number of bits shifted to the right. This acts as a
        hardware-friendly power-of-2 scale.
    :param zp: Zero point of the quantized input.
    :return: Requantized tensor in int format.
    """
    HALF = 1 << (right_shift - 1) if right_shift > 0 else 0
    res = (input.astype(QuantizedTensor) + HALF) >> right_shift
    if zp is not None:
        res += zp
    return clip_to_targeted_range(res, bits).astype(QuantizedTensor)


def _requantize_even(input: np.ndarray, bits: int, right_shift: int,
                     zp: Optional[int] = None) -> np.ndarray:
    """
    Requantize a quantized tensor to another quantization domain with the
    round-to-even mode.

    :param input: A numpy array.
    :param bits: Number of bits used to clip the scaled input.
    :param right_shift: Number of bits shifted to the right. This acts as a
        hardware-friendly power-of-2 scale.
    :param zp: Zero point of the quantized input.
    :return: Requantized tensor in int format.
    """
    res = input.astype(np.int32)
    res = res / (2 ** right_shift)
    res = np.round(res)
    if zp is not None:
        res += zp
    return clip_to_targeted_range(res, bits).astype(QuantizedTensor)

def requantize(data: np.ndarray, bits: int, right_shifts: Union[int, np.ndarray],
               zp: Optional[int] = None, per_channel: bool = False, axis: int = -1,
               rounding_type: RoundType = RoundType.UPWARD, *,
               result_type: ScalarType = ScalarType.int8) -> np.ndarray:
    """
    Requantize a quantized tensor to another quantization domain.

    :param data: A numpy array.
    :param bits: Number of bits used to clip the scaled input.
    :param right_shifts: An int or a numpy array. Each output channel has a number
        of bits shifted to the right. This acts as a hardware-friendly power-of-2 scale.
    :param zp: Zero point of the quantized input.
    :param per_channel: Default is False. If True, each output channel has one right_shift.
    :param axis: Channel axis used for per-channel requantization.
    :param rounding_type: Rounding mode used by the shift.
    :param result_type: Numeric type of the requantized tensor.
    :return: Requantized tensor in the chosen numeric type.
    """
    # Only signed integer types up to 32 bits are supported
    assert result_type in (ScalarType.int8, ScalarType.int32)
    np_result_type = result_type.numpy_type()

    if rounding_type == RoundType.TOEVEN:
        requantize_fn = _requantize_even
    else:
        requantize_fn = _requantize_half_up

    if not per_channel:
        return requantize_fn(data, bits, right_shifts, zp).astype(np_result_type)

    right_shifts = cast(np.ndarray, right_shifts)
    if axis != -1 and axis != data.ndim - 1:
        data = transpose_axis_to_the_last(data, axis)
    res = np.zeros(data.shape, dtype=QuantizedTensor)
    for i in range(data.shape[-1]):
        res[..., i] = requantize_fn(data[..., i], bits, right_shifts[i], zp)
    if axis != -1 and axis != data.ndim - 1:
        res = transpose_axis_to_the_last(res, axis)
    return res.astype(np_result_type)

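# Illustrative example (not part of the original module): a right shift of 2
# divides by 4 with round-half-up, so 100 -> (100 + 2) >> 2 = 25:
#
#   requantize(np.array([100, -100], dtype=np.int32), bits=8, right_shifts=2)
#   # -> array([25, -25], dtype=int8)
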
def float_requantization(input_quantization: Quantization,
                         output_quantization: Quantization) -> Tuple[float, float]:
    """
    Calculate floating-point correction parameters to requantize integer data using
    floating-point intermediate values. It returns S and Z such that data can be
    requantized by the calculation:

        quantized_output = round(S * float(quantized_input) + Z)

    :param input_quantization: Quantization of input data
    :param output_quantization: Quantization of output data
    :return: Requantization scale correction and zero point correction
    """
    # We have seen multiplication by zero in mobilefacedet_v1_mxnet and
    # colorization-siggraph.onnx (SWMLA-4648).
    # If the input is zero, return 0 for the scale correction and 0 for the zp correction.
    if input_quantization.scale == 0:
        sima_logger.sima_log_warning("Returning zero for requantization scale correction "
                                     "because of zero input. Please check the model for "
                                     "nodes that are algebraically equivalent to a zero "
                                     "constant (for example, multiplication with a zero "
                                     "constant).")
        return 0.0, 0.0
    sc_correction: float = output_quantization.scale / input_quantization.scale
    assert sc_correction > 0
    zp_correction: float = output_quantization.zero_point - \
        sc_correction * input_quantization.zero_point
    return sc_correction, zp_correction

def _is_close_to_pow2(n: float) -> bool:
    """
    Return true if the number is close to a power of 2.
    The tolerance is roughly 1 part in 256.
    """
    log_remainder = math.log2(n) % 1
    return log_remainder < 0.05 or (1 - log_remainder) < 0.05

def power_of_2_requantization(input_quantization: Quantization,
                              output_quantization: Quantization) -> int:
    """
    Calculate a shift factor to requantize data by a power of 2 in integer arithmetic.
    This should only be used if the input and output quantizations were chosen for
    power-of-2 requantization. It is not a good approximation in general.

    It returns a shift such that data can be requantized by the calculation:

        quantized_output = quantized_input >> shift

    The shift should use rounding to nearest, with any tie-breaking method.

    :param input_quantization: Quantization of input data
    :param output_quantization: Quantization of output data
    :return: Amount to shift right. May be negative.
    """
    sc_correction, zp_correction = float_requantization(input_quantization, output_quantization)
    shift = -calculate_normalization_shift(sc_correction, rounding=RoundType.TONEAREST)

    # This method does not do zero point correction; effectively the correction is 0.
    # Calculate the zero point correction and verify that it would round to zero
    # after shifting.
    zp_correction = round(zp_correction * 2**shift)
    assert abs(zp_correction) <= 0.5 * (1 + 2**shift), \
        "Power-of-2 requantization is not suitable for producing the wanted quantization"

    return shift

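# Illustrative example (not part of the original module), assuming the
# Quantization constructor takes (scale, zero_point, bits, min_val, max_val).
# Going from scale 508 to scale 127 is an exact division by 4, i.e. a right
# shift by 2:
#
#   in_q = Quantization(508.0, 0, 8, -0.25, 0.25)
#   out_q = Quantization(127.0, 0, 8, -0.25, 0.25)
#   power_of_2_requantization(in_q, out_q)  # -> 2, meaning q_out = q_in >> 2
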
def _rescale_by_common_factor_of_two(scale: int, zp: int, shift: int) -> Tuple[int, int, int]:
    """
    Eliminate the common power-of-2 factor of scale, zp and shift, using the
    find-first-set technique (x & -x) to find the largest power-of-2 factor of
    the integer scale and zero point corrections.

    :param scale: Integer requantization scale.
    :param zp: Requantization zero point.
    :param shift: Requantization shift.
    :return: Tuple of (scale, zp, shift) with the common power-of-2 factor removed.
    """
    shift_scale = 2 ** shift
    common_power_of_2_factor = min(shift_scale, scale & -scale)
    if zp != 0:
        common_power_of_2_factor = min(common_power_of_2_factor, zp & -zp)
    scale //= common_power_of_2_factor
    zp //= common_power_of_2_factor
    shift -= common_power_of_2_factor.bit_length() - 1
    return scale, zp, shift

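# Illustrative example (not part of the original module): (6*x + 4) >> 3 computes
# the same values as (3*x + 2) >> 2, so the common factor of 2 is removed:
#
#   _rescale_by_common_factor_of_two(scale=6, zp=4, shift=3)  # -> (3, 2, 2)
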
def requantization(input_quantization: Quantization,
                   output_quantization: Quantization,
                   bits: int = 32, *,
                   sc_correction_bits: int = 32) -> Tuple[int, int, int]:
    """
    Calculate correction factors to requantize data in integer arithmetic.
    It returns S, Z, and shift such that data can be requantized by the calculation:

        quantized_output = ((S * quantized_input) + Z) >> shift

    The shift should use rounding to nearest, with any tie-breaking method.

    :param input_quantization: Quantization of input data
    :param output_quantization: Quantization of output data
    :param bits: Integer precision of temporary values
    :param sc_correction_bits: Integer precision of the scale correction. The returned
        scale correction, taken as a signed integer, will not exceed this many bits.
    :return: Requantization scale correction, zero point correction, and right shift
    """
    assert output_quantization.bits <= input_quantization.bits
    assert sc_correction_bits > 1

    if input_quantization == output_quantization:
        # Requantization is not needed
        return 1, 0, 0

    # Calculate correction factors in floating point
    sc_correction, zp_correction = float_requantization(input_quantization, output_quantization)
    if sc_correction == 0:
        assert zp_correction == 0, \
            f"Expect zero for zp correction when scale correction is zero, but got {zp_correction}"
        return 0, 0, 0

    # Find the value range of the intermediate results (S * quantized_input)
    # and (S * quantized_input + Z), as they would be if the floating-point
    # correction factors were used directly
    input_qmin: float = input_quantization.scale * input_quantization.min_val \
        + input_quantization.zero_point
    input_qmax: float = input_quantization.scale * input_quantization.max_val \
        + input_quantization.zero_point
    intermediate_qmin = sc_correction * input_qmin + min(zp_correction, 0)
    intermediate_qmax = sc_correction * input_qmax + max(zp_correction, 0)
    intermediate_abs_max_value = max(abs(intermediate_qmin), abs(intermediate_qmax))

    # Value range of the integer type that is used for calculations.
    # We pretend it's a symmetric range [-2^(b-1) ... 2^(b-1)] to simplify the math.
    int_max = 2**(bits-1)

    # Find a scale factor for quantizing the correction factors. Choose a factor
    # that scales up intermediate_abs_max_value to fit the integer value range.
    assert intermediate_abs_max_value <= int_max, \
        "Potential overflow was detected in requantization arithmetic"
    base_shift_scale = int_max / intermediate_abs_max_value
    shift_bits = calculate_normalization_shift(base_shift_scale)

    # Ensure that the output has at least as many significant bits as the input
    output_bits = min(output_quantization.bits,
                      _symmetric_quant_bits(output_quantization.scale,
                                            output_quantization.min_val,
                                            output_quantization.max_val))
    min_preserved_bits = min(16, output_bits)
    shift_bits = min(shift_bits, bits - min_preserved_bits)

    # Ensure that the scale correction is representable in sc_correction_bits
    max_scale_bits = (1 << (sc_correction_bits - 1)) - 1  # Maximum representable value
    estimated_sc_correction = sc_correction * (2**shift_bits)
    shift_adjustment = max(0, math.ceil(math.log2(estimated_sc_correction / max_scale_bits)))
    shift_bits -= shift_adjustment
    assert shift_bits >= 0, f"Cannot represent scale correction with {sc_correction_bits} bits"
    shift_scale = 2**shift_bits

    # Quantize sc_correction and zp_correction with the chosen scale factor
    int_sc_correction = round(sc_correction * shift_scale)
    int_zp_correction = round(zp_correction * shift_scale)

    # A zero scale correction factor means that the output would be constant zero;
    # all information would be lost by requantization
    assert int_sc_correction > 0, "Error in quantization"

    # Eliminate the common factor of scale, zp and shift.
    return _rescale_by_common_factor_of_two(int_sc_correction, int_zp_correction, shift_bits)

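# Usage sketch (illustrative, not part of the original module), assuming the
# Quantization constructor takes (scale, zero_point, bits, min_val, max_val).
# The returned triple (S, Z, shift) is applied as
# q_out = ((S * q_in) + Z) >> shift:
#
#   in_q = Quantization(254.0, 0, 8, -0.5, 0.5)
#   out_q = Quantization(127.0, 0, 8, -0.5, 0.5)
#   requantization(in_q, out_q)  # -> (1, 0, 1), i.e. q_out = q_in >> 1,
#                                # matching the scale ratio 127/254 = 0.5
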
def requantization_tflite(input_quantization: Quantization,
                          output_quantization: Quantization) -> Tuple[int, int, int]:
    """
    Calculate correction factors to do TFLite requantization.
    It returns S, Z, and shift such that data can be requantized by the calculation:

        quantized_output = ((S * quantized_input) >> shift) + Z

    The shift should use rounding to nearest, with any tie-breaking method.

    The product (S * quantized_input) is assumed not to overflow. It is designed for
    a datapath that calculates this product in 64-bit precision.

    :param input_quantization: Quantization of input data. The input data's zero
        point must be 0.
    :param output_quantization: Quantization of output data
    :return: Requantization scale correction, zero point correction, and right shift
    """
    assert output_quantization.bits <= input_quantization.bits
    assert input_quantization.zero_point == 0

    if input_quantization == output_quantization:
        # Requantization is not needed
        return 1, 0, 0

    # Calculate correction factors in floating point
    sc_correction, zp_correction = float_requantization(input_quantization, output_quantization)
    assert zp_correction == output_quantization.zero_point

    # Quantize sc_correction to 8 bits
    from afe.ir.quantization_conv import decompose_power_of_2
    exponent, frac_sc_correction = decompose_power_of_2(np.array(sc_correction), RoundType.UPWARD)
    shift = int(-exponent.item() + 8)
    int_sc_correction = round(frac_sc_correction.item() * (2**8))

    # The shift value should be strictly smaller than the number of bits
    # in (S * quantized_input)
    assert 0 <= shift < input_quantization.bits + 8, "Requantization scale factor is out of range"

    # Eliminate the common factor of scale and shift
    int_sc_correction, _, shift = _rescale_by_common_factor_of_two(int_sc_correction, 0, shift)

    return int_sc_correction, output_quantization.zero_point, shift

###################
# Runtime utilities
###################

def is_quantized(data: np.ndarray) -> bool:
    return data.dtype in [np.int8, np.int16, np.int32]

def dequantize_tensor(data: Union[List[np.ndarray], Tuple[np.ndarray, ...], np.ndarray],
                      scales: List[float], zps: List[int]) \
        -> Union[List[np.ndarray], Tuple[np.ndarray, ...], np.ndarray]:
    """
    Dequantize a tensor. A tensor can be a List[np.ndarray], a
    Tuple[np.ndarray, ...], or a np.ndarray.
    """
    scales = _override_zero_scale(scales)
    if isinstance(data, (tuple, list)):
        dequantized_data = []
        for _data, _scale, _zp in zip(data, scales, zps):
            if is_quantized(_data):
                dequantized_data.append(dequantize(_data, 1. / _scale, _zp))
            else:
                dequantized_data.append(_data)
        # Cast to tuple
        if isinstance(data, tuple):
            dequantized_data = tuple(dequantized_data)
    elif is_quantized(data):
        dequantized_data = dequantize(data, 1. / scales[0], zps[0])
    else:
        return data
    # Make sure it returns a correct type
    assert type(dequantized_data) == type(data)
    return dequantized_data

def quantize_tensor(data: Union[Tuple[np.ndarray, ...], np.ndarray],
                    scales: List[Union[float, List[float]]],
                    zps: List[Union[int, List[int]]],
                    layer_bits: List[Union[int, List[int]]],
                    ) -> Union[Tuple[np.ndarray, ...], np.ndarray]:
    """
    Quantize a tensor. A tensor can be a Tuple[np.ndarray, ...] or a np.ndarray.
    """
    if isinstance(data, tuple):
        quantized_data = []
        for i, _data in enumerate(data):
            if not is_quantized(_data):
                quantized_data.append(linear_quantize(_data, scales[i], zps[i], layer_bits[i]))
            else:
                quantized_data.append(_data)
        data = tuple(quantized_data)
    elif not is_quantized(data):
        data = linear_quantize(data, scales[0], zps[0], layer_bits[0])
    return data

def dequantize_input_dict(input_dict: Dict[NodeName, Union[np.ndarray, Tuple[np.ndarray, ...]]],
                          scales: List[Union[float, List[float]]],
                          zps: List[Union[int, List[int]]],
                          ) -> Dict[NodeName, Union[np.ndarray, Tuple[np.ndarray, ...]]]:
    """
    Given an input_dict, input scales, and input zero points, dequantize each input
    in the input_dict to float if the data type is QuantizedTensor.

    :param input_dict: Dict[NodeName, Union[np.ndarray, Tuple[np.ndarray, ...]]].
        Input dictionary with (key: value) = (input_name: data)
    :param scales: List[Union[float, List[float]]]. Input scale for each input data
    :param zps: List[Union[int, List[int]]]. Input zero point for each input data
    :return: A dequantized input_dict
    """
    float_input_dict: Dict[NodeName, Union[np.ndarray, Tuple[np.ndarray, ...]]] = {}
    for i, (in_name, in_data) in enumerate(input_dict.items()):
        float_input_dict[in_name] = dequantize_tensor(in_data, scales[i], zps[i])
    return float_input_dict

def quantize_input_dict(input_dict: Dict[NodeName, Union[np.ndarray, Tuple[np.ndarray, ...]]],
                        scales: List[Union[float, List[float]]],
                        zps: List[Union[int, List[int]]],
                        layer_bits: List[Union[int, List[int]]]
                        ) -> Dict[NodeName, Union[np.ndarray, Tuple[np.ndarray, ...]]]:
    """
    Given an input_dict, input scales, and input zero points, quantize each input
    in the input_dict to QuantizedTensor if the data type is not QuantizedTensor.

    :param input_dict: Dict[NodeName, Union[np.ndarray, Tuple[np.ndarray, ...]]].
        Input dictionary with (key: value) = (input_name: data)
    :param scales: List[Union[float, List[float]]]. Input scale for each input data
    :param zps: List[Union[int, List[int]]]. Input zero point for each input data
    :param layer_bits: List[Union[int, List[int]]]. Number of bits of precision for
        each QuantizedTensor
    :return: A quantized input_dict
    """
    quantized_input_dict: Dict[NodeName, Union[np.ndarray, Tuple[np.ndarray, ...]]] = {}
    for i, (in_name, in_data) in enumerate(input_dict.items()):
        quantized_input_dict[in_name] = quantize_tensor(in_data, scales[i], zps[i], layer_bits[i])
    return quantized_input_dict

#######################
# Operator Quantization
#######################

def quantize_alpha(alpha: np.ndarray, bits: int = 8) -> Tuple[np.ndarray, int]:
    """
    Quantize the alpha for PreluOp.

    :param alpha: Alpha
    :param bits: Number of bits used for quantization
    :return: Quantized alpha, shift value
    """
    alpha_scale = compute_weight_scale(alpha, bits)
    right_shift = calculate_normalization_shift(alpha_scale)
    # The shift factor needs to be limited to 31. For alpha values which correspond
    # to a larger shift (range very close to zero), the alpha values are quantized
    # to zero, which effectively turns the PReLU operator into a regular ReLU operator.
    right_shift = min(right_shift, 31)
    quant_alpha = np.around(alpha * np.power(2.0, right_shift)).astype(np.int8)
    return quant_alpha, right_shift

def quantize_add_subtract(is_subtract: bool, input_scales: List[float], input_zps: List[int],
                          scale: float, zero_point: int, layer_bits: int,
                          in1_scale_const: int = 1, in2_scale_const: int = 1
                          ) -> Tuple[List[int], int, int]:
    """
    Quantize the add/subtract operator.

    :param is_subtract: If True, this function is used to quantize a subtract
        operator, otherwise an add operator.
    :param input_scales: Scales of the input nodes.
    :param input_zps: Zero points of the input nodes.
    :param scale: Scale of the current node.
    :param zero_point: Zero point of the current node.
    :param layer_bits: Number of bits used for quantization.
    :param in1_scale_const: Const to be folded into the 1st input scale.
    :param in2_scale_const: Const to be folded into the 2nd input scale.
    """
    # Only scalar constants are supported at this point
    assert np.isscalar(in1_scale_const)
    assert np.isscalar(in2_scale_const)

    in1_zp, in2_zp = input_zps
    layer_zp = zero_point
    in1_scale, in2_scale = input_scales
    layer_scale = _override_zero_scale(scale)
    in1_out_scale = layer_scale / in1_scale * in1_scale_const if in1_scale != 0. \
        else layer_scale / in1_scale_const
    in2_out_scale = layer_scale / in2_scale * in2_scale_const if in2_scale != 0. \
        else layer_scale / in2_scale_const

    # Normally, with non-zero inputs, select the max to normalize.
    # However, scale=1 and zp=0 are hard-coded for zero tensors.
    # If one input is zero, select the non-zero input for normalization.
    in_out_scale = max(in1_out_scale, in2_out_scale)
    if in1_scale == 0. and in1_zp == 0:
        in_out_scale = in2_out_scale
    elif in2_scale == 0. and in2_zp == 0:
        in_out_scale = in1_out_scale

    # Compute the number of bit shifts to normalize to [0.5, 1)
    layer_sh = calculate_normalization_shift(abs(in_out_scale)) + 1
    in_out_scale = in_out_scale / np.power(2.0, layer_sh)
    in1_out_scale = in1_out_scale / np.power(2.0, layer_sh)
    in2_out_scale = in2_out_scale / np.power(2.0, layer_sh)

    bound = get_bound(layer_bits)
    if round_op(in_out_scale * bound) >= bound:
        layer_sh += 1
        in1_out_scale /= 2
        in2_out_scale /= 2

    in1_scale = min(in1_out_scale * bound, bound - 1)
    in2_scale = min(in2_out_scale * bound, bound - 1)
    right_shift = -(layer_sh - (layer_bits - 1))

    in2_zp = -in2_zp if is_subtract else in2_zp
    zp_corr = layer_zp * np.power(2.0, -layer_sh) - \
        (in1_zp * in1_out_scale + in2_zp * in2_out_scale)

    in_scale = max(in1_scale, in2_scale)
    while in_scale > bound - 1:
        in1_scale /= 2
        in2_scale /= 2
        zp_corr /= 2
        right_shift -= 1
        in_scale /= 2

    lhs_scale = int(round_op(in1_scale))
    rhs_scale = int(round_op(in2_scale))
    zp_correction = int(round_op(zp_corr * bound))
    return [lhs_scale, rhs_scale], zp_correction, right_shift

def _symmetric_quant_bits(scale: float, min_val: float, max_val: float) -> int:
    """
    Find the number of bits of precision that would be required for using the given
    quantization. The quantization being analyzed is q = r * scale, where
    min_val <= r <= max_val.

    :param scale: Quantization scale
    :param min_val: Minimum value in the real number range
    :param max_val: Maximum value in the real number range
    :return: Integer bit width required to represent quantized values
    """
    assert min_val <= 0 <= max_val
    qmin = round(scale * min_val)
    qmax = round(scale * max_val)
    return max(significant_bits_signed(qmin), significant_bits_signed(qmax))

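# Illustrative example (not part of the original module): with scale 127 over the
# real range [-1, 1], quantized values span [-127, 127], which needs 8 signed bits:
#
#   _symmetric_quant_bits(127.0, -1.0, 1.0)  # -> 8
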
def quantize_multiply(lhs_quant: Quantization, rhs_quant: Quantization,
                      output_quant: Quantization, allow_full_output_precision: bool) \
        -> Tuple[int, BaseRequantization[np.ndarray], Quantization]:
    """
    Quantize the multiply operator.

    :param lhs_quant: Quantization of the first input of multiply
    :param rhs_quant: Quantization of the second input of multiply
    :param output_quant: Quantization of the output of multiply. It may be ignored
        if allow_full_output_precision is True.
    :param allow_full_output_precision: Whether 32-bit output is allowed. If True,
        then this function may ignore output_quant and output a 32-bit quantization.
        If False, then this function will quantize according to output_quant.
    :return: Tuple of intrinsic shift amount, requantization to perform, and
        quantization of the output.
    """
    assert output_quant.bits in (8, 16)
    output_quant_type = np.int8 if output_quant.bits == 8 else np.int16

    # Handle zero input tensors or a zero output tensor
    if lhs_quant.scale == 0 or rhs_quant.scale == 0 or output_quant.scale == 0:
        requant = FractionalZeroRequantization(
            1, output_quant.zero_point,
            create_and_verify_narrowing(0, RoundType.TOEVEN, output_quant_type))
        return 0, requant, output_quant

    # Find properties of the quantized product, (x - zp_x)*(y - zp_y).
    product_scale = lhs_quant.scale * rhs_quant.scale
    product_min = min(lhs_quant.min_val * rhs_quant.max_val,
                      lhs_quant.max_val * rhs_quant.min_val)
    product_max = max(lhs_quant.min_val * rhs_quant.min_val,
                      lhs_quant.max_val * rhs_quant.max_val)
    assert product_min <= 0 <= product_max

    # Ensure no arithmetic saturation. If the integer product is larger than a
    # 31-bit integer, use a right shift to make it smaller. The calculation is
    # based on 31 bits to leave one extra bit for adding zp_correction.
    product_bits = _symmetric_quant_bits(product_scale, product_min, product_max)
    intrinsic_shift = max(0, product_bits - 31)
    product_scale /= (1 << intrinsic_shift)
    product_quant = Quantization(scale=product_scale, zero_point=0, bits=32,
                                 min_val=product_min, max_val=product_max)

    # Recalculate bits with the adjusted scale
    product_bits = _symmetric_quant_bits(product_scale, product_min, product_max)

    # If the product has no more than 16 bits of precision, then requantization can
    # be done entirely in 32-bit precision, so use FractionalZero requantization or
    # omit the requantization. Otherwise, use TFLite requantization.
    if product_bits <= 16:
        if allow_full_output_precision:
            # Omit the requantization. Use the 32-bit intermediate result as the output.
            requant = FractionalZeroRequantization(
                1, 0, create_and_verify_narrowing(0, RoundType.TOEVEN, np.int32))
            output_quant = product_quant
        else:
            # Requantize using FractionalZeroRequantization
            sc_corr, zp_corr, shift = requantization(product_quant, output_quant)
            requant = FractionalZeroRequantization(
                sc_corr, zp_corr,
                create_and_verify_narrowing(shift, RoundType.TOEVEN, output_quant_type))
    else:
        # Requantize using TFLiteRequantization
        sc_corr, zp_corr, shift = requantization_tflite(product_quant, output_quant)
        requant = TFLiteRequantization(sc_correction=sc_corr, zp_correction=zp_corr,
                                       shift=shift, rounding=RoundType.TOEVEN,
                                       out_dtype=output_quant_type)

    return intrinsic_shift, requant, output_quant

def quantize_batch_matmul(lhs_quant: Quantization, rhs_quant: Quantization,
                          output_quant: Quantization) \
        -> Tuple[int, BaseRequantization[np.ndarray], Quantization]:
    assert output_quant.bits in (8, 16)
    output_quant_type = np.int8 if output_quant.bits == 8 else np.int16

    # Handle zero input tensors or a zero output tensor
    if lhs_quant.scale == 0 or rhs_quant.scale == 0 or output_quant.scale == 0:
        requant = FractionalZeroRequantization(
            1, output_quant.zero_point,
            create_and_verify_narrowing(0, RoundType.TOEVEN, output_quant_type))
        return 0, requant, output_quant

    scale_bits = 15 if output_quant_type == np.int16 else 7
    max_shift = 15 if output_quant_type == np.int16 else 23
    product_scale = lhs_quant.scale * rhs_quant.scale
    req_scale = output_quant.scale / product_scale
    shift = int(-np.ceil(np.log2(req_scale)).astype(np.int32)) + scale_bits
    sc_corr = int(np.round(req_scale * np.power(2.0, shift)).astype(np.int32))
    intrinsic_shift = int(np.maximum(shift - max_shift, 0))
    if output_quant_type == np.int8:
        assert intrinsic_shift == 0, "Intrinsic shift should be 0 for int8."
    shift -= intrinsic_shift
    zp_corr = output_quant.zero_point << shift
    requant = FractionalZeroRequantization(
        sc_corr, zp_corr,
        create_and_verify_narrowing(shift, RoundType.TOEVEN, output_quant_type))
    return intrinsic_shift, requant, output_quant

def quantize_udf(
        input_quant: Quantization,
        output_quant: Quantization,
        input_type: type,
        output_type: type,
        func: Callable[[np.ndarray], np.ndarray],
        invert_scales: bool = True,
) -> np.ndarray:
    """
    Create a lookup table for a user-defined function.

    :param input_quant: Quantization of the input.
    :param output_quant: Quantization of the output.
    :param input_type: Type of the LUT input.
    :param output_type: Type of the LUT output.
    :param func: Function to be approximated by the lookup table.
    :param invert_scales: If true, the input scale factors are inverted.
    :return: Lookup table representing func for the quantized input and output.
        It is a numpy array of int8 or int16 values.
    """
    if invert_scales:
        input_scale = 1.0 / _override_zero_scale(input_quant.scale)
        output_scale = 1.0 / _override_zero_scale(output_quant.scale)
    else:
        input_scale = input_quant.scale
        output_scale = output_quant.scale
    lut = ComputeLookupTable(
        func,
        input_quant=(input_scale, input_quant.zero_point),
        output_quant=(output_scale, output_quant.zero_point),
        input_type=input_type,
        output_type=output_type
    )
    return lut.build()

def get_input_quantization_func(scale: float, zp: int, layer_bit: int) \
        -> Callable[[np.ndarray], np.ndarray]:
    """
    Return a function that takes a numpy array and quantizes the data using the
    scale and zero point, following the equation below:

        quantized_input = (input / S) + zero_point

    :param scale: scale = (1/S) in the above equation.
    :param zp: Zero point of the quantized input.
    :param layer_bit: Number of bits used to clip the scaled input.
    """
    def _quant_func(data: np.ndarray) -> np.ndarray:
        return linear_quantize(data, scale, zp, layer_bit).astype(np.int8)
    return _quant_func

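# Usage sketch (illustrative, not part of the original module): build a reusable
# int8 quantization function for an input with scale 127 and zero point 0:
#
#   quant_fn = get_input_quantization_func(scale=127.0, zp=0, layer_bit=8)
#   quant_fn(np.array([-1.0, 0.0, 1.0]))  # -> array([-127, 0, 127], dtype=int8)
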
def _simplify_quantized_clip(
        attrs: attributes.ClipQuantAttrs, zero_point: int, scalar_type: ScalarType, *,
        asymmetry: bool = False
) -> Union[attributes.ClipQuantAttrs, attributes.ReluQuantAttrs, None]:
    """
    Try to simplify a clip activation function in the quantized domain. Return the
    attributes for the simplified activation function (clip, relu, or nothing).

    :param attrs: Quantized attributes of clip.
    :param zero_point: Zero point in the quantized domain. Determines if clip can
        be converted to relu.
    :param scalar_type: ScalarType used to initialize ReluAttrs. Has to be an
        integer type.
    :param asymmetry: Whether asymmetric quantization is used. This is only used
        for error checking.
    :return: Simplified activation function.
    """
    clip_min = attrs.a_min
    clip_max = attrs.a_max
    iinfo = np.iinfo(scalar_type.numpy_type())
    representable_min = iinfo.min
    representable_max = iinfo.max

    # Clipping to the representable range is implicitly carried out by saturating
    # arithmetic, and supersedes clipping to any looser range.
    if clip_min <= representable_min and clip_max >= representable_max:
        return None

    # Asymmetric quantization should always be simplified.
    if asymmetry:
        sima_logger.sima_log_info("A clip operator was not eliminated with asymmetric "
                                  "quantization. This may indicate precision loss.")

    # Clipping to the positive range can be simplified to relu.
    if clip_min == zero_point and clip_max >= representable_max:
        return attributes.ReluQuantAttrs(attrs.shape, zero_point)

    # Can't simplify clip.
    return attrs


def _simplify_quantized_relu(attrs: attributes.ReluQuantAttrs, zero_point: int,
                             scalar_type: ScalarType) \
        -> Union[attributes.ReluQuantAttrs, None]:
    """
    Try to simplify a relu activation function in the quantized domain. Return the
    attributes for the simplified activation function (relu or nothing).

    :param attrs: Quantized attributes of relu.
    :param zero_point: Zero point in the quantized domain.
    :param scalar_type: ScalarType used to initialize ReluAttrs. Has to be an
        integer type.
    :return: Simplified activation function.
    """
    representable_min = np.iinfo(scalar_type.numpy_type()).min

    # If the zero point is the minimum value, then relu will be implicitly carried
    # out by saturating arithmetic.
    if zero_point == representable_min:
        return None

    # Can't simplify relu.
    return attributes.ReluQuantAttrs(attrs.input_shape, zero_point)

def quantize_clip_attrs(attrs: attributes.ClipAttrs, scalar_type: ScalarType,
                        quant: Quantization) -> attributes.ClipQuantAttrs:
    """
    Quantize the attributes of the clip operator.

    Calculate the boundaries of the clip operator based on its quantization
    parameters and data type.

    Args:
        attrs: Attributes of the clip operator
        scalar_type: Scalar data type of the quantized clip operator
        quant: Quantization parameters to apply to the clip operator

    Returns:
        Attributes of the quantized clip operator containing boundary parameters
        calculated for the quantized operator.
    """
    iinfo = np.iinfo(scalar_type.numpy_type())
    representable_min = iinfo.min
    representable_max = iinfo.max
    quantized_clip_min = max(representable_min,
                             round(attrs.a_min * quant.scale + quant.zero_point))
    quantized_clip_max = min(representable_max,
                             round(attrs.a_max * quant.scale + quant.zero_point))
    return attributes.ClipQuantAttrs(quantized_clip_min, quantized_clip_max,
                                     attrs.shape, scalar_type)

def quantize_activation(attrs: Union[attributes.ClipAttrs, attributes.ReluAttrs, None],
                        quantization: Quantization, scalar_type: ScalarType, *,
                        quant_config: Optional[QuantizationConfigs] = None) \
        -> Union[attributes.ClipQuantAttrs, attributes.ReluQuantAttrs, None]:
    """
    Quantize a simple activation function (clip, relu, or nothing) and simplify it
    if possible. No requantization is introduced to these activation functions;
    the input and output quantization scales are always the same.

    Quantization may simplify an activation function by taking advantage of the
    clipping behavior of saturating arithmetic.

    :param attrs: Attributes of the activation function to quantize
    :param quantization: Quantization to apply to this activation function
    :param scalar_type: Scalar data type that the activation function will be
        evaluated on. Has to be an integer type.
    :param quant_config: Parameters that were used to choose 'quantization'.
        Used for error checking.
    :return: Attributes of the quantized activation function. It may be a different
        type than the input.
    """
    assert scalar_is_integral(scalar_type)
    if attrs is None:
        return None
    if quantization.scale == 0:
        # The input of the activation operator is constant zeros. The activation
        # does nothing.
        return None
    if isinstance(attrs, attributes.ClipAttrs):
        clip_attrs = quantize_clip_attrs(attrs, scalar_type, quantization)
        return _simplify_quantized_clip(clip_attrs, quantization.zero_point, scalar_type,
                                        asymmetry=quant_config and quant_config.asymmetry.get())
    else:
        assert isinstance(attrs, attributes.ReluAttrs)
        return _simplify_quantized_relu(attrs, quantization.zero_point, scalar_type)

def requantize_activation(attrs: Union[attributes.ClipQuantAttrs, attributes.ReluQuantAttrs, None],
                          zero_point: int,
                          requantization: BaseRequantization[np.ndarray],
                          scalar_type: ScalarType) \
        -> Union[attributes.ClipQuantAttrs, attributes.ReluQuantAttrs, None]:
    """
    Requantize an activation function.

    This represents transforming the expression requant(activ(x)), where the
    activation is evaluated before requantization, to an equivalent expression
    newactiv(requant(x)), where the new activation is evaluated after
    requantization. The new activation could be simpler by taking advantage of
    integer saturation.

    :param attrs: Activation function's attributes. This must be for a quantized
        activation.
    :param zero_point: Original zero point of the activation function, before
        requantization. Ignored if attrs is None.
    :param requantization: Requantization to perform. The input type of the
        requantization is assumed to be int16.
    :param scalar_type: ScalarType used to initialize ReluAttrs. Has to be an
        integer type.
    :return: Transformed activation function's attributes (clip, relu, or nothing).
    """
    assert scalar_is_integral(scalar_type)
    if attrs is None:
        return None
    new_zero_point = int(ml_kernels.requantization.requantize(np.array(zero_point),
                                                              requantization).item())
    if isinstance(attrs, attributes.ClipQuantAttrs):
        a_min = ml_kernels.requantization.requantize(np.array(attrs.a_min), requantization).item()
        a_max = ml_kernels.requantization.requantize(np.array(attrs.a_max), requantization).item()
        return _simplify_quantized_clip(
            attributes.ClipQuantAttrs(a_min, a_max, attrs.shape, scalar_type),
            new_zero_point, scalar_type)
    else:
        assert isinstance(attrs, attributes.ReluQuantAttrs)
        return _simplify_quantized_relu(attrs, new_zero_point, scalar_type)

def _convert_shift_to_multiplier(shift: int | np.ndarray) -> float:
    """
    Convert a per-tensor shift to a multiplier such that
    x >> shift == x * _convert_shift_to_multiplier(shift).
    Raise an error if the shift is per-channel.

    :param shift: Shift value
    :return: Multiplier value
    """
    assert isinstance(shift, int), "Only per-tensor shift is supported"
    return 1.0 / (1 << shift)

def requantize_quantization(quantization: Quantization,
                            requant: BaseRequantization[np.ndarray]) -> Quantization:
    """
    Get the quantization of the result of requantizing a tensor. This would be the
    quantization at the output of a Requantize node, for the given input and
    requantization.

    :param quantization: Quantization of the input tensor
    :param requant: Requantization to perform
    :return: Quantization of the result of applying requant to the input tensor
    """
    iinfo = np.iinfo(requant.out_dtype)

    # If the input is a zero tensor, the output must also be a zero tensor
    if quantization.scale == 0:
        if isinstance(requant, (TFLiteRequantization, FractionalZeroRequantization)):
            assert requant.zp_correction == 0, \
                "Cannot handle zero point correction on a zero tensor"
        return Quantization(0.0, 0, iinfo.bits, 0.0, 0.0)

    # Given the old quantization q = r * scale + zero_point and requantization
    # q2 = requant(q), find the new quantization q2 = r * scale2 + zero_point2.
    if isinstance(requant, FractionalZeroRequantization):
        pow2_scale = _convert_shift_to_multiplier(requant.narrowing.shift)
        new_scale = pow2_scale * quantization.scale * requant.sc_correction
        new_zp = round(pow2_scale * (quantization.zero_point * requant.sc_correction
                                     + requant.zp_correction))
    elif isinstance(requant, TFLiteRequantization):
        pow2_scale = _convert_shift_to_multiplier(requant.shift)
        new_scale = pow2_scale * quantization.scale * requant.sc_correction
        new_zp = round(pow2_scale * quantization.zero_point * requant.sc_correction) \
            + requant.zp_correction
    elif isinstance(requant, ArithFoldedRequantization):
        pow2_scale = _convert_shift_to_multiplier(requant.narrowing.shift)
        new_scale = pow2_scale * quantization.scale
        new_zp = round(pow2_scale * quantization.zero_point)
    else:
        raise TypeError("Unrecognized requantization type")

    # Clip (min_val, max_val) to the representable range.
    # This models the effect of saturating arithmetic in the requantization operator.
    min_val = max(quantization.min_val, (iinfo.min - new_zp) / new_scale)
    max_val = min(quantization.max_val, (iinfo.max - new_zp) / new_scale)
    assert min_val < max_val

    return Quantization(new_scale, new_zp, iinfo.bits, min_val, max_val)

def quantize_prelu(layer_bits: int, alpha: Union[np.ndarray, float]) -> Tuple[int, int]:
    """
    Quantize the PRelu alphas and return the quantized alphas and right shifts.

    :param layer_bits: Number of bits used for quantization
    :param alpha: Union[np.ndarray, float]. Alpha in float data type
    :return: Tuple of (quantized alpha, right shift)
    """
    alpha, right_shift = compute_power_of_2_scale_and_shift(alpha, layer_bits, 8)
    quantized_alpha = round_op(alpha)
    # Cast to the quantized parameter data type
    if quantized_alpha.size == 1:
        quantized_alpha = QuantizedParam(quantized_alpha)
    else:
        quantized_alpha = quantized_alpha.astype(QuantizedParam)
    return quantized_alpha, right_shift

def quantize_reciprocal(input_qtype: attributes.QuantResultTensorType) \
        -> attributes.AwesomeCalibAttrs:
    """
    Quantize the reciprocal part of divide.

    :param input_qtype: Quantization for the rhs argument of divide.
    :return: Calibration attributes AwesomeCalibAttrs, which are used in the
        ReciprocalOp UDF.
    """
    assert isinstance(input_qtype.quant, Quantization)
    reciprocal_min_value: float = 1.0 / input_qtype.quant.min_val
    reciprocal_max_value: float = 1.0 / input_qtype.quant.max_val
    reciprocal_scale: float = 127.0 / max(abs(reciprocal_min_value), abs(reciprocal_max_value))
    reciprocal_zp = 0
    reciprocal_quantization = Quantization(reciprocal_scale, reciprocal_zp, 8,
                                           reciprocal_min_value, reciprocal_max_value)
    output_qtype = attributes.QuantResultTensorType(input_qtype.type, reciprocal_quantization)
    inputs_qtype = {InputName('data'): TensorValue(input_qtype)}
    calib_attrs = attributes.AwesomeCalibAttrs(quant=TensorValue(output_qtype),
                                               input_quant=inputs_qtype)
    return calib_attrs

def quantize_lrn(attrs: attributes.LRNAttrs, input_quant: Quantization, quant: Quantization) \
        -> attributes.LRNQuantAttrs:
    """
    Quantize LRN, which is implemented based on
    quantized_local_response_normalization from the ml_kernels repo:

        out = lut(square_sum(x)) * x

    where the lut function is:

        lambda x: (bias + alpha / size * x) ** (-beta)

    :param attrs: LRN attributes.
    :param input_quant: Quantization of input data
    :param quant: Layer quantization
    :return: Quantized LRN attributes.
    """
    input_scale = input_quant.scale
    input_zp = input_quant.zero_point
    input_min = input_quant.min_val
    input_max = input_quant.max_val
    output_scale = quant.scale
    output_zp = quant.zero_point

    # Calculate the scale and zero point for the sum of input squares from the
    # input calibration parameters
    sq_min = 0
    sq_max = max(input_min * input_min, input_max * input_max) * attrs.size
    sq_scale = compute_scale(asymmetry=True, layer_bits=quant.bits,
                             min_val=sq_min, max_val=sq_max)
    sq_zp = compute_zero_point(asymmetry=True, layer_bits=quant.bits,
                               min_val=sq_min, max_val=sq_max)

    # Calculate the maximal integer value for the sum of input squares
    max_sq = max((-128 - input_zp) * (-128 - input_zp),
                 (127 - input_zp) * (127 - input_zp)) * attrs.size

    # LUT requantization parameters.
    # Calculate the maximum number of bits for lut_scale_q to avoid overflow
    lut_scale_bits = 31 - np.ceil(np.log2(max_sq)).astype(np.int32)
    lut_scale = sq_scale / (input_scale * input_scale)
    lut_sh = -calculate_normalization_shift(lut_scale, rounding=RoundType.UPWARD) + lut_scale_bits
    lut_scale_q = int(np.round(lut_scale * np.power(2.0, lut_sh)))

    # Calculate the LUT
    func = lambda x: (attrs.bias + attrs.alpha / attrs.size * x) ** (-attrs.beta)
    lut_output_scale = max(func(sq_min), func(sq_max)) / np.iinfo(np.int8).max
    lut = ComputeLookupTable(func,
                             input_quant=(1.0 / sq_scale, sq_zp),
                             output_quant=(lut_output_scale, 0),
                             input_type=np.int8,
                             output_type=np.int8)
    lookup_table = lut.build()

    # Calculate requantization parameters for output = input * LUT.
    # Use a maximum of 15 bits for layer_scale_q to avoid overflow
    layer_scale = output_scale * lut.output_scale / input_scale
    layer_sh = -calculate_normalization_shift(layer_scale, rounding=RoundType.UPWARD) + 15
    layer_scale_q = int(np.round(layer_scale * np.power(2.0, layer_sh)))

    quant_attrs = attributes.LRNQuantAttrs(axis=attrs.axis, size=attrs.size,
                                           shape=attrs.shape, input_zp=input_zp,
                                           lut_scale=lut_scale_q,
                                           lut_zp_corr=int(sq_zp << lut_sh),
                                           lut_sh=int(lut_sh),
                                           output_scale=layer_scale_q,
                                           output_zp_corr=int(output_zp << layer_sh),
                                           output_sh=int(layer_sh),
                                           lookup_table=lookup_table)
    return quant_attrs

def _fractional_zero_requantization(out_scale: float, in_scale: float, out_max: int,
                                    out_zp: int, scale_bits: Optional[int] = None
                                    ) -> FractionalZeroRequantization:
    """
    Calculate fractional zero requantization.

    :param out_scale: Output quantization scale in SiMa's convention.
    :param in_scale: Input quantization scale in SiMa's convention.
    :param out_max: Theoretical worst-case maximum for the output value prior to
        requantization.
    :param out_zp: Output quantization zero point.
    :param scale_bits: Number of bits to use. Calculated if not provided.
    :return: Fractional zero requantization.
    """
    # Calculate the number of bits to be used for the scale constant such that
    # saturation is avoided.
    scale_bits = 30 - np.floor(np.log2(out_max)).astype(np.int32) if scale_bits is None \
        else scale_bits
    scale = out_scale / in_scale
    sh = -calculate_normalization_shift(scale, rounding=RoundType.UPWARD) + scale_bits
    # Limit the shift value
    sh = min(sh, 23)
    scale_q = np.round(scale * np.power(2.0, sh)).astype(np.int32)
    if scale_q == 0:
        raise ValueError("Invalid scale. Please try using different calibration and/or "
                         "quantization schemes.")
    zp_corr = out_zp << sh
    return FractionalZeroRequantization(
        int(scale_q), int(zp_corr),
        create_and_verify_narrowing(int(sh), RoundType.UPWARD, np.int8))


def _get_softmax_int16_requant(out_scale: float, in_scale: float, out_zp: int) -> \
        Tuple[int, FractionalZeroRequantization]:
    """
    Get fractional zero requantization for int16 Softmax.

    :param out_scale: Output quantization scale in SiMa's convention.
    :param in_scale: Input quantization scale in SiMa's convention.
    :param out_zp: Output quantization zero point.
    :return: Pre-shift value and fractional zero requantization.
    """
    lut_scale_bits = 15
    scale = out_scale / in_scale
    lut_sh = -int(np.ceil(np.log2(scale))) + lut_scale_bits
    lut_scale_q = int(np.round(scale * np.power(2.0, lut_sh)))
    pre_sh = max(lut_sh - lut_scale_bits, 0)
    lut_sh -= pre_sh
    lut_zp_corr = (out_zp << lut_sh)
    requant_lut = FractionalZeroRequantization(lut_scale_q, lut_zp_corr,
                                               Narrowing(lut_sh, RoundType.TOEVEN, np.int16))
    return pre_sh, requant_lut

[docs] def quantize_softmax(attrs: attributes.SoftmaxAttrs, input_quant: Quantization, quant: Quantization, intermediate_min_max: Dict[str, Tuple[float, float]], enable_int16: bool) \ -> attributes.SoftmaxQuantAttrs: """ Quantize Softmax which is implemented based on softmax implementation from ml_kernels repo: exp = lut_exp(x) # lut_exp(x) = exp(x) exp_sum_rec = lut_rec(np.sum(exp)) # lut_rec(x) = 1/x ofm = exp * exp_sum_rec :param attrs: Softmax attributes. :param input_quant: Quantization of input data :param quant: Layer quantization :param intermediate_min_max: Dict of intermediates min/max values. :param enable_int16: Whether to use int8 or int16 quantization. :return: Quantized Softmax attributes """ from ml_kernels.np_operators import compute_intermediate_lut import ml_kernels.requantization as requant axis = attrs.axis output_scale = quant.scale output_zp = quant.zero_point dtype = np.int16 if enable_int16 else np.int8 # Create exp(x) lookup table. exp_in_min = input_quant.min_val - input_quant.max_val exp_in_max = 0.0 # Correct exp_in_min value so calculated exp_zp remains in valid range. This range is being # calculated in the way that max_sum_q remains in 23 bit range (32 - 9(scale)) max_allowed_zp = int(np.ceil(127 - np.sqrt(2 ** 23))) if enable_int16: exp_in_min = exp_in_min symmetric = True else: max_allowed_min = _get_max_allowed_min(exp_in_max, max_allowed_zp, lambda x: float(np.exp(x)), lambda x: float(np.log(x))) exp_in_min = min(exp_in_min, max_allowed_min) symmetric = False lut_exp, exp_zp, exp_scale = compute_intermediate_lut(lambda x: float(np.exp(x)), exp_in_min, exp_in_max, dtype=dtype, symmetric=symmetric) # Calculate scale and zero point for rec input. n = attrs.input_shape[axis] if enable_int16: sum_exp = intermediate_min_max['sum_exp'] sum_exp_min, sum_exp_max = min(sum_exp), min(4 * max(sum_exp), float(n)) rec_in_min, rec_in_max = sum_exp_max / 255.0, sum_exp_max else: sum_exp_min, sum_exp_max = intermediate_min_max['sum_exp'] rec_in_min, rec_in_max = max(sum_exp_max / 255.0, sum_exp_min), sum_exp_max # Correct rec_in_min value so calculated rec_zp remains in valid range. This range is being # calculated in the way that max_sum_q remains in 23 bit range (32 - 9(scale)) max_allowed_rec_in_min = _get_max_allowed_min(rec_in_max, max_allowed_zp, lambda x: float(np.reciprocal(x)), lambda x: float(np.reciprocal(x))) rec_in_min = min(rec_in_min, max_allowed_rec_in_min) rec_in_scale = compute_scale(asymmetry=True, layer_bits=quant.bits, min_val=rec_in_min, max_val=rec_in_max) # Allow zp to be out of range. rec_in_zp = -(int(np.round(rec_in_min * rec_in_scale)) + np.iinfo(dtype).max + 1) # Calculate reciprocal(x) lookup table. lut_rec, rec_zp, rec_scale = compute_intermediate_lut(lambda x: float(np.reciprocal(x)), rec_in_min, rec_in_max, func_defined_at_zero=False, dtype=dtype) # Calculate lut re-quant params. # LUT input requnantization if enable_int16: lut_pre_sh, requant_lut = _get_softmax_int16_requant(out_scale=rec_in_scale, in_scale=exp_scale, out_zp=rec_in_zp) else: max_sum_q = (127 - exp_zp) * n requant_lut = _fractional_zero_requantization(out_scale=rec_in_scale, in_scale=exp_scale, out_max=max_sum_q, out_zp=rec_in_zp) lut_pre_sh = None # Update zp_correction to match kernel implementation. zp_correction = requant_lut.zp_correction - requant_lut.sc_correction * exp_zp * n requant_lut = requant.FractionalZeroRequantization(requant_lut.sc_correction, zp_correction, requant_lut.narrowing) # Calculate output re-quant params. 
    if enable_int16:
        out_pre_sh, requant_output = _get_softmax_int16_requant(out_scale=output_scale,
                                                                in_scale=exp_scale * rec_scale,
                                                                out_zp=output_zp)
    else:
        max_out_q = (127 - exp_zp) * (127 - rec_zp)
        requant_output = _fractional_zero_requantization(out_scale=output_scale,
                                                         in_scale=exp_scale * rec_scale,
                                                         out_max=max_out_q,
                                                         out_zp=output_zp)
        out_pre_sh = None

    quant_attrs = attributes.SoftmaxQuantAttrs(axis=axis,
                                               input_shape=attrs.input_shape,
                                               exp_zp=exp_zp, rec_zp=rec_zp,
                                               requant_lut=requant_lut,
                                               requant_output=requant_output,
                                               lookup_table_exp=lut_exp,
                                               lookup_table_rec=lut_rec,
                                               enable_int16=enable_int16,
                                               lut_input_pre_shift=lut_pre_sh,
                                               output_pre_shift=out_pre_sh)
    return quant_attrs
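
# Illustrative sketch (not part of the original module): invoking quantize_softmax for an int8
# tensor. The SoftmaxAttrs and Quantization constructor signatures shown here are assumptions
# made for illustration; the real dataclasses may require different or additional fields.
def _demo_quantize_softmax() -> None:
    attrs = attributes.SoftmaxAttrs(axis=1, input_shape=(1, 64))    # assumed constructor
    in_q = Quantization(scale=16.0, zero_point=0, bits=8,
                        min_val=-8.0, max_val=8.0)                  # assumed constructor
    out_q = Quantization(scale=255.0, zero_point=-128, bits=8,
                         min_val=0.0, max_val=1.0)                  # assumed constructor
    # 'sum_exp' holds the calibrated min/max of the sum of exponentials.
    q_attrs = quantize_softmax(attrs, in_q, out_q,
                               intermediate_min_max={'sum_exp': (1.0, 20.0)},
                               enable_int16=False)
    print(q_attrs.exp_zp, q_attrs.rec_zp)
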
def _get_max_allowed_min(in_max: float, zp_limit: int,
                         fun: Callable[[float], float],
                         inv_fun: Callable[[float], float]) -> float:
    """
    Calculate the largest allowed minimum value that yields a zero point below the required
    limit when the LUT is created by the compute_intermediate_lut function.
    """
    assert zp_limit < -128
    lut_max = fun(in_max)
    max_allowed_min = lut_max / (1 - 255.0 / (zp_limit + 128))
    if inv_fun(max_allowed_min) > in_max:
        max_allowed_min = lut_max * (1 - 255.0 / (zp_limit + 128))
    return inv_fun(max_allowed_min)


def _compute_scale_and_zp(min_val: float, max_val: float, bits: int = 8) -> Tuple[float, int]:
    """
    Helper function for calculating scale and zero point; it is used in the quantization of
    the Layer and Instance Normalization operators.

    The zero point is allowed to fall outside the [-128, 127] range, as that yields better
    accuracy; the quantized value is later clipped to the [-128, 127] range.
    """
    scale = compute_scale(asymmetry=True, layer_bits=bits, min_val=min_val, max_val=max_val)
    # Allow zp to be out of the [-128, 127] range.
    zp = -(np.round(min_val * scale) + 128).astype(np.int32)
    return scale, zp
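
# Illustrative sketch (not part of the original module): a range whose minimum sits just above
# zero produces a zero point slightly outside [-128, 127], which callers clip later. The sample
# range is arbitrary.
def _demo_compute_scale_and_zp() -> None:
    scale, zp = _compute_scale_and_zp(min_val=0.004, max_val=1.0, bits=8)
    # scale maps the real range [0.004, 1.0] onto 255 integer steps;
    # zp re-centers it so that min_val maps near -128 (here zp = -129, out of range).
    print(scale, zp)
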
[docs] def quantize_layer_norm(attrs: attributes.LayerNormAttrs,
                        input_quant: Quantization,
                        quant: Quantization,
                        intermediate_min_max: Dict[str, Tuple[float, float]]) \
        -> attributes.LayerNormQuantAttrs:
    """
    Quantize LayerNorm, which is implemented based on the layer norm implementation
    from the ml_kernels repo:

        LayerNorm(input, axis, epsilon) = (input - m) / Sqrt(var + epsilon),

    where
        m = ReduceMean(input, axis, keepdims=True),
        var = ReduceMean((input - m) ** 2, axis, keepdims=True).

    Use a LUT for the reciprocal of the sqrt function.

    :param attrs: LayerNormAttrs attributes.
    :param input_quant: Quantization of input data.
    :param quant: Layer quantization.
    :param intermediate_min_max: Dict of intermediates' min/max values.
    :return: Quantized LayerNormAttrs attributes.
    """
    from ml_kernels.np_operators import compute_intermediate_lut

    axis = attrs.axis
    input_scale = input_quant.scale
    output_scale = quant.scale
    output_zp = quant.zero_point

    # Calculate mean re-quant params.
    n = attrs.input_shape[axis]
    mean_max = n * 127
    requant_mean = _fractional_zero_requantization(out_scale=1, in_scale=n,
                                                   out_max=mean_max, out_zp=0)

    # Create rsqrt lookup table.
    lut_fun: Callable[[float], float] = lambda x: 1 / float(np.sqrt(x + attrs.epsilon))
    var_min, var_max = intermediate_min_max['var']
    lut_in_min, lut_in_max = max(var_max / 128, var_min), var_max
    # Correct the min value so the calculated rsqrt_zp remains in the valid range. The range is
    # chosen so that max_outq remains in the 23-bit range (32 - 9 bits of scale).
    max_allowed_exp_zp = int(np.ceil(127 - (2 ** 23) / 255))
    inv_lut_fun: Callable[[float], float] = lambda x: (1 / x) ** 2 - attrs.epsilon
    max_allowed_min = _get_max_allowed_min(lut_in_max, max_allowed_exp_zp, lut_fun, inv_lut_fun)
    lut_in_min = min(lut_in_min, max_allowed_min)
    lut_rsqrt, rsqrt_zp, rsqrt_scale = compute_intermediate_lut(lut_fun, lut_in_min, lut_in_max)

    # Calculate LUT input re-quant params.
    lut_in_scale, lut_in_zp = _compute_scale_and_zp(lut_in_min, lut_in_max, quant.bits)
    max_lut_in = 127 * n
    requant_lut_in = _fractional_zero_requantization(out_scale=lut_in_scale,
                                                     in_scale=input_scale * input_scale * n,
                                                     out_max=max_lut_in,
                                                     out_zp=lut_in_zp,
                                                     scale_bits=7)

    # Calculate output re-quant params.
    max_outq = 255 * (127 - rsqrt_zp)
    requant_output = _fractional_zero_requantization(out_scale=output_scale,
                                                     in_scale=input_scale * rsqrt_scale,
                                                     out_max=max_outq,
                                                     out_zp=output_zp)

    return attributes.LayerNormQuantAttrs(axis=attrs.axis,
                                          input_shape=attrs.input_shape,
                                          zp_rsqrt=int(rsqrt_zp),
                                          lookup_table_rsqrt=lut_rsqrt,
                                          requant_mean=requant_mean,
                                          requant_lut_input=requant_lut_in,
                                          requant_output=requant_output)
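
# Illustrative sketch (not part of the original module): the mean requantization above divides
# a sum of n int8 values by n. For n = 64 (a power of two) this reduces to a pure
# multiply-and-shift with no zero-point correction.
def _demo_layer_norm_mean_requant() -> None:
    n = 64
    rq = _fractional_zero_requantization(out_scale=1, in_scale=n, out_max=n * 127, out_zp=0)
    # sc_correction / 2**shift approximates 1/64, i.e. the ReduceMean divisor.
    print(rq.sc_correction, rq.narrowing.shift)
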
[docs] def quantize_instance_norm(attrs: attributes.InstanceNormAttrs,
                           input_quant: Quantization,
                           mean_quant: Quantization,
                           variance_quant: Quantization,
                           quant: Quantization) -> attributes.InstanceNormQuantAttrs:
    """
    Quantize the Instance Normalization operator: (input - mean) / sqrt(variance + epsilon).

    Args:
        attrs: Instance Normalization attributes.
        input_quant: Quantization of the input data.
        mean_quant: Quantization of the mean input data.
        variance_quant: Quantization of the variance input data.
        quant: Layer quantization.

    Returns:
        Quantized Instance Normalization attributes.
    """
    from ml_kernels.np_operators import compute_intermediate_lut

    input_scale = max(input_quant.scale, mean_quant.scale)
    output_scale = quant.scale
    output_zp = quant.zero_point

    # Create rsqrt lookup table.
    lut_fun: Callable[[float], float] = lambda x: 1 / float(np.sqrt(x + attrs.epsilon))
    var_min, var_max = variance_quant.min_val, variance_quant.max_val
    lut_in_min, lut_in_max = max(var_max / 128, var_min), var_max
    # Correct the min value so the calculated rsqrt_zp remains in the valid range. The range is
    # chosen so that max_outq remains in the 23-bit range (32 - 9 bits of scale).
    max_allowed_exp_zp = int(np.ceil(127 - (2 ** 23) / 255))
    inv_lut_fun: Callable[[float], float] = lambda x: (1 / x) ** 2 - attrs.epsilon
    max_allowed_min = _get_max_allowed_min(lut_in_max, max_allowed_exp_zp, lut_fun, inv_lut_fun)
    lut_in_min = min(lut_in_min, max_allowed_min)
    lut_rsqrt, rsqrt_zp, rsqrt_scale = compute_intermediate_lut(lut_fun, lut_in_min, lut_in_max)

    # Calculate output re-quant params.
    max_outq = 255 * (127 - rsqrt_zp)
    requant_output = _fractional_zero_requantization(out_scale=output_scale,
                                                     in_scale=input_scale * rsqrt_scale,
                                                     out_max=max_outq,
                                                     out_zp=output_zp)

    return attributes.InstanceNormQuantAttrs(attrs=attrs,
                                             lut_rsqrt=lut_rsqrt,
                                             zp_rsqrt=rsqrt_zp,
                                             requant_out=requant_output)
[docs] def quantize_rms_norm(attrs: attributes.RMSNormAttrs,
                      input_quant: Quantization,
                      quant: Quantization,
                      intermediate_min_max: Dict[str, Tuple[float, float]],
                      enable_lut_int16: bool) -> attributes.RMSNormQuantAttrs:
    """
    Quantize RMS Normalization, which is implemented based on the rms norm implementation
    from the ml_kernels repo:

        RMSNorm(x, axis, epsilon) = x / Sqrt(ReduceMean(x ** 2, axis, keepdims=True) + epsilon)

    Use a LUT for the reciprocal of the sqrt function.

    :param attrs: RMSNorm attributes.
    :param input_quant: Quantization of input data.
    :param quant: Layer quantization.
    :param intermediate_min_max: Dict of intermediates' min/max values.
    :param enable_lut_int16: If True, quantize the LUT to int16; otherwise to int8.
    :return: Quantized RMSNorm attributes.
    """
    from ml_kernels.np_operators import compute_intermediate_lut

    # Define number of bits and data types.
    lut_bits = 16 if enable_lut_int16 else 8
    lut_dtype = np.int16 if enable_lut_int16 else np.int8
    lut_out_bits = 16 if enable_lut_int16 else 8
    max_lut_sh = 31 - lut_bits
    max_out_sh = 31 - lut_out_bits

    # Input and output quantization parameters.
    input_scale = input_quant.scale
    out_scale = quant.scale
    y_zero_point = quant.zero_point

    # The reciprocal-of-sqrt part of the rms norm function is approximated using a LUT.
    # Calculate the LUT input range.
    sq_min, sq_max = intermediate_min_max['reduce_mean']
    input_min = input_quant.min_val
    input_max = input_quant.max_val
    x_sq_max = np.maximum(input_min * input_min, input_max * input_max)
    # The constants here were found experimentally while searching for the best range for the
    # LUT, so that loss of information is minimized.
    n = attrs.input_shape[-1]
    # k multiplies sq_max so that the maximum input is moved towards the middle of the LUT,
    # where the accuracy is better; 1/sqrt(x) is almost flat near the maximum, so the accuracy
    # is poor there. For int8, the small number of bits means this does not improve accuracy,
    # so k is set to 1.
    k = 1 if lut_bits == 8 else 4
    sq_max = np.minimum(k * sq_max, x_sq_max * n)
    # sq_max / 255 is used for lut_in_min so that the minimum is not close to 0, since
    # 1/sqrt(min) would be huge. If sq_min is bigger than sq_max / 255, then sq_min is used,
    # since it is further from 0. On the other hand, a bigger sq_min means a bigger lut_zp;
    # for int16 that could produce a zp outside the int16 range, which is why sq_max / 255
    # is always used for int16.
    if lut_bits == 8:
        lut_in_min, lut_in_max = max(sq_max / 255, sq_min), sq_max
    else:
        lut_in_min, lut_in_max = sq_max / 255, sq_max

    # Calculate LUT quantization parameters.
    lut_in_scale = compute_scale(asymmetry=True, layer_bits=lut_bits,
                                 min_val=lut_in_min, max_val=lut_in_max)
    lut_in_zp = -(np.round(lut_in_min * lut_in_scale) + (2 ** (lut_bits - 1))).astype(np.int32)

    # Calculate LUT re-quantization parameters.
    lut_in_scale_bits = 7 if lut_bits == 8 else 15
    lut_in_req_scale = lut_in_scale / (input_scale * input_scale * n)
    sh = np.ceil(np.log2(lut_in_req_scale)).astype(np.int32)
    lut_in_sh = lut_in_scale_bits - sh
    lut_in_scale_q = int(np.round(lut_in_req_scale * np.power(2.0, lut_in_sh)).astype(np.int32))
    lut_in_pre_sh = np.maximum(lut_in_sh - max_lut_sh, 0)
    lut_in_sh -= lut_in_pre_sh
    req_lut_in = TFLiteRequantization(lut_in_scale_q, int(lut_in_sh), RoundType.TOEVEN,
                                      int(lut_in_zp), lut_dtype)

    # Calculate LUT.
    lut_fun = lambda x: 1 / np.sqrt(x + attrs.epsilon)
    lut_rsqrt, rsqrt_zp, rsqrt_scale = compute_intermediate_lut(lut_fun, lut_in_min, lut_in_max,
                                                                dtype=lut_dtype)

    # Calculate output re-quantization parameters.
    output_scale = out_scale / (input_scale * rsqrt_scale)
    sh = np.ceil(np.log2(output_scale)).astype(np.int32)
    out_scale_bits = 10 if lut_bits == 8 else 15
    out_sh = out_scale_bits - sh
    out_scale_q = int(np.round(output_scale * np.power(2.0, out_sh)).astype(np.int32))
    out_pre_sh = np.maximum(out_sh - max_out_sh, 0)
    out_sh -= out_pre_sh
    req_output = TFLiteRequantization(out_scale_q, int(out_sh), RoundType.TOEVEN,
                                      int(y_zero_point), np.int8)

    return attributes.RMSNormQuantAttrs(input_shape=attrs.input_shape,
                                        zp_ifm=input_quant.zero_point,
                                        lookup_table_rsqrt=lut_rsqrt,
                                        zp_rsqrt=rsqrt_zp,
                                        requant_lut_input=req_lut_in,
                                        requant_output=req_output,
                                        lut_input_pre_shift=int(lut_in_pre_sh),
                                        output_pre_shift=int(out_pre_sh),
                                        enable_lut_int16=enable_lut_int16)
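
# Illustrative sketch (not part of the original module): the multiplier / shift / pre-shift
# decomposition used above, on a standalone sample scale. A pre-shift is split off only when
# the total shift exceeds what the accumulator width allows.
def _demo_rms_norm_shift_decomposition() -> None:
    req_scale = 0.003                               # arbitrary sample requantization scale
    scale_bits, max_sh = 15, 15
    sh = int(np.ceil(np.log2(req_scale)))           # -8 for 0.003
    total_sh = scale_bits - sh                      # 23
    scale_q = int(np.round(req_scale * 2.0 ** total_sh))
    pre_sh = max(total_sh - max_sh, 0)              # 8 bits shifted off early
    total_sh -= pre_sh                              # 15 remain for the requantization
    # scale_q / 2**(pre_sh + total_sh) approximates req_scale.
    print(scale_q, pre_sh, total_sh)
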
[docs] def quantization_data_value_to_output_list(quantization: DataValue[Quantization]) \
        -> Tuple[List[float], List[int], List[int], List[int], List[int]]:
    """
    Convert a DataValue of Quantization object(s) to lists of quantization-related values.
    This is used for interfacing to code that stores quantization information in five
    separate lists.

    :param quantization: DataValue of Quantization object(s) to convert to a tuple of
        quantization parameters.
    :return: Lists of scales, zero points, bits, minimum and maximum values.
    """
    scales = []
    zero_points = []
    bits = []
    min_vals = []
    max_vals = []

    # Flatten the contents of quantization into the lists.
    def extract_from(q: DataValue[Quantization]):
        if isinstance(q, TensorValue):
            scales.append(q.value.scale)
            zero_points.append(q.value.zero_point)
            bits.append(q.value.bits)
            min_vals.append(q.value.min_val)
            max_vals.append(q.value.max_val)
        elif isinstance(q, TupleValue):
            for element in q.elements:
                extract_from(element)
        else:
            raise TypeError("Invalid type for DataValue")

    extract_from(quantization)
    return scales, zero_points, bits, min_vals, max_vals
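
# Illustrative sketch (not part of the original module): flattening a nested DataValue into the
# five parallel lists. The positional TensorValue/TupleValue construction and the Quantization
# constructor signature are assumptions; the real dataclasses may differ.
def _demo_quantization_flattening() -> None:
    q = Quantization(scale=2.0, zero_point=0, bits=8, min_val=-1.0, max_val=1.0)  # assumed
    tree = TupleValue([TensorValue(q), TupleValue([TensorValue(q)])])             # assumed
    scales, zps, bits, mins, maxs = quantization_data_value_to_output_list(tree)
    # Each leaf TensorValue contributes one entry per list, in traversal order.
    print(scales, zps, bits, mins, maxs)
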
def _fix_narrowing(narrowing: Narrowing) -> Narrowing:
    """
    Change the data type of the right-shift array, if it is present, to uint8.
    """
    shift = narrowing.shift
    if isinstance(shift, np.ndarray):
        return dataclasses.replace(narrowing, shift=shift.astype(np.uint8))
    else:
        return narrowing
[docs] def fix_requantization(requantization: BaseRequantization[np.ndarray]) -> BaseRequantization[np.ndarray]:
    """
    Change the data type of the right-shift array, if it is present, to uint8.
    """
    if isinstance(requantization, ml_kernels.requantization.TFLiteRequantization):
        shift = requantization.shift
        if isinstance(shift, np.ndarray):
            return dataclasses.replace(requantization, shift=shift.astype(np.uint8))
        else:
            return requantization
    elif isinstance(requantization, (ml_kernels.requantization.ArithFoldedRequantization,
                                     ml_kernels.requantization.FractionalZeroRequantization)):
        return dataclasses.replace(requantization,
                                   narrowing=_fix_narrowing(requantization.narrowing))
    raise TypeError("Unrecognized requantization type")
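
# Illustrative sketch (not part of the original module): per-channel shifts often arrive as a
# default integer ndarray; fix_requantization narrows them to uint8. Keyword names follow the
# TFLiteRequantization construction used elsewhere in this module.
def _demo_fix_requantization() -> None:
    rq = TFLiteRequantization(sc_correction=12345,
                              shift=np.array([7, 8, 9]),   # per-channel right shifts
                              rounding=RoundType.TOEVEN,
                              zp_correction=0,
                              out_dtype=np.int8)
    fixed = fix_requantization(rq)
    print(fixed.shift.dtype)  # uint8
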
[docs] def cast_calibration_inputs(values: List[np.ndarray], cast: QuantizationCast):
    """
    Quantize a list of tensors according to a cast. An identity cast returns the
    original values.
    """
    if isinstance(cast, IdentityCast):
        return values
    elif isinstance(cast, QuantCast):
        return [quantize_tensor(tensor, [cast.scale], [cast.zero_point], [cast.num_bits])
                for tensor in values]
    elif isinstance(cast, DequantCast):
        return [dequantize_tensor(tensor, [cast.scale], [cast.zero_point]) for tensor in values]
    elif isinstance(cast, RequantCast):
        requant = create_requantization_from_cast(cast)
        return [ml_kernels.requantization.requantize(tensor, requant) for tensor in values]
    elif isinstance(cast, ConvertCast):
        output_type = cast.out_type
        return [tensor.astype(output_type.numpy_type()) for tensor in values]
    else:
        raise RuntimeError(f"Unhandled cast: {cast}")
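
# Illustrative sketch (not part of the original module): applying a QuantCast to float
# calibration data. The QuantCast constructor arguments are assumptions based on the fields
# read above; the real dataclass may differ.
def _demo_cast_calibration_inputs() -> None:
    data = [np.array([-1.0, 0.0, 1.0], dtype=np.float32)]
    cast = QuantCast(scale=127.0, zero_point=0, num_bits=8)  # assumed constructor
    quantized = cast_calibration_inputs(data, cast)
    print(quantized[0])  # roughly [-127, 0, 127]
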
[docs] def create_requantization_from_cast(cast: RequantCast) -> BaseRequantization[np.ndarray]:
    """
    Get the Requantization that implements the given cast.

    :param cast: Cast to perform.
    :return: Requantization.
    """
    rq_input = cast.get_input_quantization()
    rq_output = cast.get_output_quantization()
    out_dtype = np.int8 if rq_output.bits == 8 else np.int16
    if cast.requant_method == RequantMethod.arith_folded:
        shift = power_of_2_requantization(rq_input, rq_output)
        if shift >= 0:
            requant = narrowing_requantization(shift, RoundType.TOEVEN, out_dtype)
        else:
            # Scale factor is greater than 1, equivalent to a left shift. Use a multiplication.
            assert shift >= -8, "Scale factor is too large for requantization"
            requant = TFLiteRequantization(sc_correction=1 << -shift, shift=0,
                                           rounding=RoundType.TRUNC, zp_correction=0,
                                           out_dtype=out_dtype)
    else:
        assert cast.requant_method in (RequantMethod.fractional_zero, RequantMethod.scaled_fz)
        sc_corr, zp_corr, shift = requantization(rq_input, rq_output, 32)
        if cast.requant_method == RequantMethod.scaled_fz:
            # The input and output quantization should have been chosen to ensure zp_corr == 0.
            # Allow error up to +/- (0.5 * 2**shift), which would be +/- 0.5 in the output.
            zp_corr_bound = 1 << max(shift - 1, 0)
            assert abs(zp_corr) <= zp_corr_bound, \
                "Scale-only requantization is not suitable for producing the wanted quantization"
            zp_corr = 0
        requant = FractionalZeroRequantization(sc_corr, zp_corr,
                                               Narrowing(shift, RoundType.TOEVEN, out_dtype))
    return requant
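
# Illustrative sketch (not part of the original module): the arith_folded branch above maps a
# power-of-two scale change either to a right shift or, when the factor exceeds 1, to a
# multiplier. Pure-numpy illustration of both directions on sample data.
def _demo_arith_folded_shift() -> None:
    x = np.array([100, -48], dtype=np.int32)
    # Scale factor 1/4: right shift by 2 (rounding behavior of Narrowing not shown here).
    print(x >> 2)        # [25, -12]
    # Scale factor 4 (shift == -2): fold into a multiplication instead.
    print(x * (1 << 2))  # [400, -192]
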