#########################################################
# Copyright (C) 2020 SiMa Technologies, Inc.
#
# This material is SiMa proprietary and confidential.
#
# This material may not be copied or distributed without
# the express prior written permission of SiMa.
#
# All rights reserved.
#########################################################
# Code owner: Joey Chou
#########################################################
import dataclasses
import math
import numpy as np
from enum import Enum
from typing import List, Union, Tuple, Optional, Callable, cast, Dict, TypeVar, Any
from sima_utils.logging import sima_logger
import afe.ir.attributes as attributes
from afe.ir.defines import (
QuantizedTensor, QuantizedParam, NodeName, Float, InputName, DataValue, Quantization,
RequantMethod, TensorValue, TupleValue, QuantizedTensorNew, QuantizedTensorInt16, QuantizationCast, IdentityCast,
QuantCast, DequantCast, RequantCast, ConvertCast
)
from afe.ir.tensor_type import ScalarType, scalar_is_integral
from afe.ir.utils import (
transpose_axis_to_the_last, create_and_verify_narrowing
)
from afe.core.configs import QuantizationConfigs
from ml_kernels.math_helpers import RoundType
from ml_kernels.np_operators import ideal_udf
from ml_kernels.requantization import Narrowing, BaseRequantization, FractionalZeroRequantization, \
TFLiteRequantization, ArithFoldedRequantization, narrow, narrowing_requantization
import ml_kernels.requantization
from ml_kernels.udf import ComputeLookupTable
_TENSOR = TypeVar("_TENSOR")
_ZERO_KERNEL_THRESHOLD = 0.001
class QNNDtype(str, Enum):
    """Data types used in QNN operations"""
    # The member values below are assumed; they are reconstructed from the
    # DTYPE_BOUNDS usage that follows.
    INT8 = "int8"
    UINT8 = "uint8"
    INT32 = "int32"


DTYPE_BOUNDS = {QNNDtype.INT8: (-128, 127),
                QNNDtype.UINT8: (0, 255),
                QNNDtype.INT32: (-2147483648, 2147483647)}
def _override_zero_scale(scale: Union[float, np.ndarray, List]) -> Union[float, np.ndarray, List]:
"""
    If scale is equal to 0, set it to 1.
    Some parts of the code do not behave properly if scale is 0; in that case we set the scale to 1.
"""
if isinstance(scale, np.ndarray):
return np.where(scale == 0, 1.0, scale)
elif isinstance(scale, List):
return [sc if sc != 0 else 1.0 for sc in scale]
else:
assert isinstance(scale, float)
return scale if scale != 0 else 1.0
def round_op(x: float, rounding_type: RoundType = RoundType.TOEVEN) -> float:
"""
    Round a float using the given rounding mode.
    :param x: A float32 number to be rounded
    :param rounding_type: Rounding mode; defaults to round-to-even
    :return: Rounded result
"""
if rounding_type == RoundType.UPWARD:
return np.floor(x + 0.5) # round to +inf
elif rounding_type == RoundType.TOEVEN:
return np.round(x) # round to even
elif rounding_type == RoundType.TONEAREST:
return np.sign(x) * np.floor(np.abs(x) + 0.5) # round away from 0
elif rounding_type == RoundType.TRUNC:
return np.trunc(x)
else:
raise ValueError(f"Expected {[t.name for t in RoundType]} for the rounding type, got {rounding_type}")
def calculate_normalization_shift(scale: Union[float, np.ndarray],
rounding: RoundType = RoundType.TRUNC) -> Union[float, np.ndarray]:
"""
    Calculate the number of bit shifts needed to normalize a scale.
    The original scale is normalized, up to the chosen rounding, by dividing it by (2**shift).
"""
from afe.ir.quantization_conv import decompose_power_of_2
if isinstance(scale, np.ndarray):
exp, _ = decompose_power_of_2(scale, rounding=rounding)
shift = exp.astype(int)
else:
exp, _ = decompose_power_of_2(np.asarray(scale), rounding=rounding)
shift = int(exp)
return shift
def get_bound(bits: int, signed: bool = True) -> int:
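    """
    Return the magnitude bound of the quantized range for the given bit width:
    2**(bits - 1) if signed, else 2**bits, cast to QuantizedParam.
    """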
if signed:
bound = np.power(2., bits - 1)
else:
bound = np.power(2., bits)
return bound.astype(QuantizedParam)
def clip_to_targeted_range(x: Union[int, np.ndarray], bits: int,
restricted_range: bool = False) -> Union[int, np.ndarray]:
"""
    Clip x to the target range determined by the given bit number.
    :param x: Numpy array or int
    :param bits: Number of bits used to determine the min and max number
    :param restricted_range: If True, the range is symmetric: abs(a_min) == abs(a_max)
"""
bound = get_bound(bits)
a_min = -bound
a_max = bound - 1
if restricted_range:
a_min += 1
return np.clip(x, a_min, a_max)
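# Example (illustrative): with bits=8 the range is [-128, 127], or [-127, 127]
# with restricted_range=True.
# >>> clip_to_targeted_range(200, bits=8)
# 127
# >>> clip_to_targeted_range(-200, bits=8, restricted_range=True)
# -127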
def compute_scale(asymmetry: bool, layer_bits: int,
min_val: float, max_val: float,
include_real_zero_point: bool = False,
) -> float:
"""
Compute a linear quantization scale for mapping the range (min_val, max_val) onto the quantized integer range
determined by layer_bits, include_real_zero_point, and asymmetry.
The computed scale is the reciprocal of the scale in TFLite's convention.
:param asymmetry: If true, do asymmetric quantization.
:param layer_bits: Number of bits used for quantization.
:param min_val: Minimum value.
:param max_val: Maximum value.
:param include_real_zero_point: If True, force the float dynamic range
covering zero.
    :return: Computed scale s such that real numbers r are converted to integers q by the formula q = round(s * r).
"""
if include_real_zero_point:
min_val = min(0, min_val)
max_val = max(0, max_val)
if asymmetry:
# -128 ~ 127
# Measure the size of the range [MIN, MAX]
max_integer = 2 ** layer_bits - 1
max_float = max_val - min_val
else:
# Measure the larger of the ranges [MIN, 0] and [0, MAX]
max_integer = 2 ** (layer_bits - 1) - 1
max_float = max(abs(min_val), abs(max_val))
# Workaround for division by 0. In this case the range is a single point
# and there is no meaningful scale factor.
if max_val == min_val:
max_float = max_val
# 1/S = max_integer / max_float
if max_float == 0:
# The range is a single point and there is no meaningful
# scale factor. Use special value 0 to signal this.
return 0.
return float(abs(max_integer / max_float))
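# Example (illustrative): asymmetric 8-bit quantization of the range [-1, 3]
# maps a span of 4.0 onto 255 integer steps, so the scale is 255 / 4 = 63.75
# and q = round(63.75 * r).
# >>> compute_scale(asymmetry=True, layer_bits=8, min_val=-1.0, max_val=3.0)
# 63.75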
def compute_zero_point(asymmetry: bool, layer_bits: int,
min_val: float, max_val: float,
restricted_range: bool = False
) -> int:
"""
Given min and max value, compute the zero point.
:param asymmetry: If true, do asymmetric quantization.
:param layer_bits: Number of bits used for quantization.
:param min_val: Minimum value.
:param max_val: Maximum value.
:param restricted_range: If True, the dynamic range will be equal
at negative and positive side.
    :return: Zero point.
"""
if not asymmetry:
return 0
if min_val == 0 and max_val != 0:
        ''' When min_val == 0, a zero point of -(2**(layer_bits - 1) - 1) noticeably
        decreases accuracy, so we make sure the zero point is -(2**(layer_bits - 1))
        in that case.
        '''
        return int(-np.power(2, layer_bits - 1))
if min_val == 0 and max_val == 0:
''' If all values are zero, then zp is also 0.
'''
return 0
fp_zp = -0.5 * (min_val + max_val) * compute_scale(
asymmetry, layer_bits, min_val, max_val,
include_real_zero_point=True)
if not restricted_range:
# Bias toward negative side by 1
fp_zp -= (1. / 2)
zp = int(round_op(fp_zp))
return int(clip_to_targeted_range(zp, layer_bits, restricted_range))
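# Example (illustrative): for the same range [-1, 3], the floating-point zero
# point is -0.5 * (-1 + 3) * 63.75 - 0.5 = -64.25, which rounds and clips to -64.
# >>> compute_zero_point(asymmetry=True, layer_bits=8, min_val=-1.0, max_val=3.0)
# -64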
def significant_bits_signed(n: int) -> int:
"""
Get the smallest signed integer bit width that can represent the given integer.
> significant_bits_signed(-129) = 9
> significant_bits_signed(-128) = 8
> significant_bits_signed(127) = 8
> significant_bits_signed(128) = 9
"""
# Calculate bit-width in sign-magnitude representation
bits = 1 + n.bit_length()
# Convert to bit-width in twos complement representation
if n < 0 and (-n & (-1 - n)) == 0:
# n is a negative power of 2. Its size is 1 bit less than sign-magnitude representation.
bits -= 1
return bits
def compute_power_of_2_scale_and_shift(scale: Union[float, np.ndarray],
input_bit: int,
output_bit: int
) -> Union[Tuple[int, int], Tuple[np.ndarray, np.ndarray]]:
"""
    Given a float scale or a vector of scales, and the number of quantized bits for input and output,
    return a quantized scale and right shift.
    :param scale: Union[float, np.ndarray]
    :param input_bit: int. Number of bits used for input quantization
    :param output_bit: int. Number of bits used for output quantization
    :return: Union[Tuple[int, int], Tuple[np.ndarray, np.ndarray]]. Tuple of (scale, right shift)
"""
right_shift = -calculate_normalization_shift(scale)
quantized_scale = (scale * np.power(2., right_shift)).astype(np.float32)
# Avoid overflow
input_bound = get_bound(input_bit)
output_bound = get_bound(output_bit)
if isinstance(quantized_scale, np.ndarray):
assert quantized_scale.ndim == 1
assert scale.shape == quantized_scale.shape
for i in range(len(quantized_scale)):
if round_op(quantized_scale[i] * (input_bound // 2)) == output_bound:
right_shift[i] -= 1
quantized_scale[i] /= 2.
else:
if round_op(quantized_scale * (input_bound // 2)) == output_bound:
right_shift -= 1
quantized_scale /= 2.
right_shift += (input_bit - 1)
quantized_scale *= input_bound
return quantized_scale, right_shift
def compute_weight_scale(weight: np.ndarray, bits: int) -> float:
"""
Compute weight scale. Weights are always quantized symmetrically.
:param weight: Weight tensor.
:param bits: Number of bits used to quantize weight.
    :return: Scale of weight.
"""
return compute_scale(asymmetry=False, layer_bits=bits, min_val=np.min(weight), max_val=np.max(weight),
include_real_zero_point=True)
def compute_weight_scale_per_channel(weight: np.ndarray, bits: int) -> np.ndarray:
"""
Compute per-channel weight scales. The expected layout of weight is AwesomeConvWeightLayout.
:param weight: Weight tensor in AwesomeConvWeightLayout format.
:param bits: Number of bits used to quantize weight.
    :return: An array of scales of weight.
"""
scales = np.zeros((weight.shape[-1])).astype(float)
for i in range(weight.shape[-1]):
scales[..., i] = compute_weight_scale(weight[..., i], bits)
return scales
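# Example (illustrative): per-channel scales for a weight with two output
# channels in the last dimension.
# >>> compute_weight_scale_per_channel(np.array([[1.0, 0.5]]), bits=8)
# array([127., 254.])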
def linear_scale(input: np.ndarray, scale: float, bits: int, clip: bool = True) -> np.ndarray:
"""
    Linearly scale the input based on the scale. Clip the scaled input based on the bit number.
    :param input: A numpy array.
    :param scale: A scale factor that is used to scale the input to a target range.
    :param bits: Number of bits used to clip the scaled input.
    :param clip: If true, clip the linear scale result to the given dynamic range.
    :return: Scaled input.
"""
res = input * scale
if clip:
res = round_op(res)
res = clip_to_targeted_range(res, bits)
return res.astype(QuantizedTensor)
def linear_scale_per_channel(input: np.ndarray, scale: np.ndarray, bits: int, clip: bool = True) -> np.ndarray:
"""
    Linearly scale the input based on the scale. Clip the scaled input based on the bit number.
    The output channel has to be at the last dimension.
    :param input: A numpy array.
    :param scale: A numpy array of scale factors that are used to scale the input to
        different target ranges in different channels.
    :param bits: Number of bits used to clip the scaled input.
    :param clip: If true, clip the linear scale results to the given dynamic range.
    :return: Scaled input.
"""
res = np.zeros(input.shape, dtype=QuantizedTensor)
for i in range(input.shape[-1]):
res[..., i] = linear_scale(input[..., i], scale[i], bits, clip)
return res
def linear_quantize(input: np.ndarray, scale: float, zp: int, bits: int) -> np.ndarray:
"""
quantized_input = (input / S) + zero_point.
:param input: A numpy array.
:param scale: scale = (1/S) in the above equation.
:param zp: Zero point of the quantized input.
    :param bits: Number of bits used to clip the scaled input.
    :return: Quantized input.
"""
scale = _override_zero_scale(scale)
delta = -scale if scale < 0 else 1 / scale
rounded = round_op(input / delta) + zp
res = clip_to_targeted_range(rounded, bits)
out_type = QuantizedTensorInt16 if bits == 16 else QuantizedTensorNew
return np.array(res.astype(out_type))
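# Example (illustrative): quantizing with the scale and zero point computed for
# the range [-1, 3] above; the result dtype assumes QuantizedTensorNew is int8.
# >>> linear_quantize(np.array([-1.0, 0.0, 1.5]), scale=63.75, zp=-64, bits=8)
# array([-128,  -64,   32], dtype=int8)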
def linear_quantize_with_quantization(input: np.ndarray, quantization: Quantization) -> np.ndarray:
"""
Apply a quantization to a floating-point tensor to produce a quantized tensor.
:param input: Floating-point tensor
:param quantization: Quantization to apply
:return: Quantized tensor
"""
return linear_quantize(input, quantization.scale, quantization.zero_point, quantization.bits)
def quantize_value(value: Any, q: DataValue[Optional[Quantization]]) -> Any:
"""
Quantize a value according to the given quantization.
Values consist of arrays and tuples.
:param value: Value to quantize. It must consist of numpy arrays and tuples.
:param q: Quantization of the value. None means that the value is not quantized
and so it will be returned unchanged.
:return: Quantized value. It has the same tuple structure as the input.
"""
if isinstance(q, TensorValue):
assert isinstance(value, np.ndarray)
if q.value is None:
# Not quantized
return value
else:
# Quantized
return linear_quantize_with_quantization(value, q.value)
elif isinstance(q, TupleValue):
assert isinstance(value, (list, tuple))
assert len(value) == len(q.elements)
return tuple(quantize_value(v, p) for v, p in zip(value, q.elements))
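# Example (illustrative, assuming TensorValue and TupleValue can be constructed
# directly from the wrapped value): one quantized element and one pass-through
# element in a tuple.
# >>> q = TupleValue([TensorValue(Quantization(63.75, -64, 8, -1.0, 3.0)),
# ...                 TensorValue(None)])
# >>> quantize_value((np.array([1.5]), np.array([2.0])), q)
# (array([32], dtype=int8), array([2.]))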
def dequantize_value(value: Any, q: DataValue[Optional[Quantization]]) -> Any:
"""
Dequantize a value according to the given quantization.
Values consist of arrays and tuples.
:param value: Value to dequantize. It must consist of numpy arrays and tuples.
:param q: Quantization of the value. None means that the value is not quantized
and so it will be returned unchanged.
:return: Dequantized value. It has the same tuple structure as the input.
"""
if isinstance(q, TensorValue):
assert isinstance(value, np.ndarray)
if q.value is None:
# Not quantized
return value
else:
# Quantized
return dequantize(value, 1 / q.value.scale, q.value.zero_point)
elif isinstance(q, TupleValue):
assert isinstance(value, (list, tuple))
assert len(value) == len(q.elements)
return tuple(dequantize_value(v, p) for v, p in zip(value, q.elements))
def get_zero_kernel_mask_per_channel(weight: np.ndarray, threshold: float) -> np.ndarray:
"""
Return the mask of zero kernel. The kernel layout of weight must be in AwesomeConvWeightLayout.
:param weight: Weights for convolution in AwesomeConvWeightLayout layout.
:param threshold: If the sum of kernel's absolute value is smaller than the threshold, the kernel will
be treated as a zero kernel.
    :return: Mask of zero kernels. True means the kernel is a zero kernel.
"""
mask = []
for i in range(weight.shape[-1]):
mask.append(np.sum(abs(weight[..., i])) <= threshold)
return np.array(mask)
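# Example (illustrative): channel 0 is all zeros, channel 1 is not.
# >>> w = np.zeros((1, 1, 1, 2))
# >>> w[..., 1] = 5.0
# >>> get_zero_kernel_mask_per_channel(w, _ZERO_KERNEL_THRESHOLD)
# array([ True, False])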
def dequantize(input: np.ndarray, scale: float, zp: int) -> np.ndarray:
"""
Original equation:
quantized_input = (input / S) + zero_point.
    Reverse it to get the dequantization equation:
        dequantized_input = (quantized_input - zero_point) * S
    :param input: A numpy array.
    :param scale: scale = S in the above equation.
    :param zp: Zero point of the quantized input.
    :return: Dequantized input.
"""
return ((input.astype(np.int32) - zp) * scale).astype(Float)
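# Example (illustrative): inverting the quantization q = round(63.75 * r) - 64.
# The result dtype assumes Float is float32.
# >>> dequantize(np.array([32], dtype=np.int8), scale=1 / 63.75, zp=-64)
# array([1.5058824], dtype=float32)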
def _requantize_half_up(input: np.ndarray, bits: int, right_shift: int, zp: Optional[int] = None) -> np.ndarray:
"""
Requantize a quantized tensor to another quantization domain
:param input: A numpy array.
    :param bits: Number of bits used to clip the scaled input.
    :param right_shift: Number of bits shifted to the right. This acts as a hardware-friendly power-of-2 scale.
    :param zp: Zero point of the quantized input.
    :return: Requantized tensor in int format.
"""
HALF = 1 << (right_shift - 1) if right_shift > 0 else 0
res = (input.astype(QuantizedTensor) + HALF) >> right_shift
if zp is not None:
res += zp
return clip_to_targeted_range(res, bits).astype(QuantizedTensor)
def _requantize_even(input: np.ndarray, bits: int, right_shift: int, zp: Optional[int] = None) -> np.ndarray:
"""
Requantize a quantized tensor to another quantization domain with the round-to-even mode.
:param input: A numpy array.
    :param bits: Number of bits used to clip the scaled input.
    :param right_shift: Number of bits shifted to the right. This acts as a hardware-friendly power-of-2 scale.
    :param zp: Zero point of the quantized input.
    :return: Requantized tensor in int format.
"""
res = input.astype(np.int32)
res = res / (2 ** right_shift)
res = np.round(res)
if zp is not None:
res += zp
return clip_to_targeted_range(res, bits).astype(QuantizedTensor)
def requantize(data: np.ndarray,
bits: int, right_shifts: Union[int, np.ndarray],
zp: Optional[int] = None,
per_channel: bool = False,
axis: int = -1,
rounding_type: RoundType = RoundType.UPWARD, *,
result_type: ScalarType = ScalarType.int8) -> np.ndarray:
"""
Requantize a quantized tensor to another quantization domain
:param data: A numpy array.
    :param bits: Number of bits used to clip the scaled input.
    :param right_shifts: An int, or a numpy array with one entry per output channel, giving the
        number of bits shifted to the right. This acts as a hardware-friendly power-of-2 scale.
    :param zp: Zero point of the quantized input.
    :param per_channel: Default is False. If True, each output channel has one right_shift.
    :param axis: Axis of the channel dimension, used when per_channel is True.
    :param rounding_type: Rounding mode used when shifting.
    :param result_type: Numeric type of requantized tensor.
    :return: Requantized tensor in chosen numeric type.
"""
# Only support signed integer types up to 32 bit
assert result_type in (ScalarType.int8, ScalarType.int32)
np_result_type = result_type.numpy_type()
if rounding_type == RoundType.TOEVEN:
requantize_fn = _requantize_even
else:
requantize_fn = _requantize_half_up
if not per_channel:
return requantize_fn(data, bits, right_shifts, zp).astype(np_result_type)
right_shifts = cast(np.ndarray, right_shifts)
if axis != -1 and axis != data.ndim - 1:
data = transpose_axis_to_the_last(data, axis)
res = np.zeros(data.shape, dtype=QuantizedTensor)
for i in range(data.shape[-1]):
res[..., i] = requantize_fn(data[..., i], bits, right_shifts[i], zp)
if axis != -1 and axis != data.ndim - 1:
res = transpose_axis_to_the_last(res, axis)
return res.astype(np_result_type)
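# Example (illustrative): a right shift of 2 with half-up rounding divides by 4,
# computed as (x + 2) >> 2.
# >>> requantize(np.array([100, -100]), bits=8, right_shifts=2)
# array([ 25, -25], dtype=int8)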
def float_requantization(input_quantization: Quantization, output_quantization: Quantization) -> Tuple[float, float]:
"""
Calculate floating-point correction parameters to requantize integer data using
floating-point intermediate values.
It returns S and Z such that data can be requantized by the calculation:
quantized_output = round(S * float(quantized_input) + Z)
:param input_quantization: Quantization of input data
:param output_quantization: Quantization of output data
:return: Requantization scale correction and zero point correction
"""
# We have seen multiply by zero in mobilefacedet_v1_mxnet and colorization-siggraph.onnx (SWMLA-4648)
# If input is zero, return 0 for scale correction and 0 for zp correction
if input_quantization.scale == 0:
        sima_logger.sima_log_warning("Returning zero for requantization scale correction because of zero input. "
                                     "Please check the model for nodes that are algebraically equivalent to "
                                     "a zero constant (for example, multiply with a zero constant).")
return 0.0, 0.0
sc_correction: float = output_quantization.scale / input_quantization.scale
assert sc_correction > 0
zp_correction: float = output_quantization.zero_point - sc_correction * input_quantization.zero_point
return sc_correction, zp_correction
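# Example (illustrative): halving the scale while keeping the zero point gives
# S = 0.5 and Z = -64 - 0.5 * (-64) = -32, so q_out = round(0.5 * q_in - 32).
# >>> float_requantization(Quantization(63.75, -64, 8, -1.0, 3.0),
# ...                      Quantization(31.875, -64, 8, -1.0, 3.0))
# (0.5, -32.0)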
def _is_close_to_pow2(n: float) -> bool:
"""
Return true if the number is close to a power of 2.
The tolerance is roughly 1 part in 256.
"""
log_remainder = math.log2(n) % 1
return log_remainder < 0.05 or (1 - log_remainder) < 0.05
def power_of_2_requantization(input_quantization: Quantization, output_quantization: Quantization) -> int:
"""
Calculate a shift factor to requantize data by a power of 2 in integer arithmetic.
This should only be used if the input and output quantization were chosen
for power of 2 requantization. It is not a good approximation in general.
It returns a shift such that data can be requantized by the calculation:
quantized_output = quantized_input >> shift
The shift should use rounding to nearest, with any tie-breaking method.
:param input_quantization: Quantization of input data
:param output_quantization: Quantization of output data
:return: Amount to shift right. May be negative.
"""
sc_correction, zp_correction = float_requantization(input_quantization, output_quantization)
shift = -calculate_normalization_shift(sc_correction, rounding=RoundType.TONEAREST)
# This method does not do zero point correction; effectively the correction is 0.
# Calculate the zero point correction and verify that it would round to zero after
# shifting.
zp_correction = round(zp_correction * 2**shift)
assert abs(zp_correction) <= 0.5 * (1 + 2**shift), \
"Power-of-2 requantization is not suitable for producing the wanted quantization"
return shift
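# Example (illustrative): reducing the scale from 128 to 1 (both zero points 0)
# is an exact power of 2, so the data is requantized by q_out = q_in >> 7.
# >>> power_of_2_requantization(Quantization(128.0, 0, 8, -1.0, 1.0),
# ...                           Quantization(1.0, 0, 8, -1.0, 1.0))
# 7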
def _rescale_by_common_factor_of_two(scale: int, zp: int, shift: int) -> Tuple[int, int, int]:
"""
Eliminate the common factor of scale, zp and shift using the find-first-set technique
(x & -x) to find the max power of 2 factor of int_sc_correction and int_zp_correction.
:param scale: Integer requantization scale.
:param zp: Requantization zero point.
:param shift: Requantization shift.
    :return: Tuple of (scale, zp, shift) with the common power-of-2 factor divided out.
"""
shift_scale = 2 ** shift
common_power_of_2_factor = min(shift_scale, scale & -scale)
if zp != 0:
common_power_of_2_factor = min(
common_power_of_2_factor, zp & -zp)
scale //= common_power_of_2_factor
zp //= common_power_of_2_factor
shift -= common_power_of_2_factor.bit_length() - 1
return scale, zp, shift
def requantization(input_quantization: Quantization, output_quantization: Quantization,
bits: int = 32, *, sc_correction_bits: int = 32) -> Tuple[int, int, int]:
"""
Calculate correction factors to requantize data in integer arithmetic.
It returns S, Z, and shift such that data can be requantized by the calculation:
quantized_output = ((S * quantized_input) + Z) >> shift
The shift should use rounding to nearest, with any tie-breaking method.
:param input_quantization: Quantization of input data
:param output_quantization: Quantization of output data
:param bits: Integer precision of temporary values
:param sc_correction_bits: Integer precision of the scale correction.
The returned scale correction, taken as a signed integer, will not exceed this many bits.
:return: Requantization scale correction, zero point correction, and right shift
"""
assert output_quantization.bits <= input_quantization.bits
assert sc_correction_bits > 1
if input_quantization == output_quantization:
# Requantization is not needed
return 1, 0, 0
# Calculate correction factors in floating point
sc_correction, zp_correction = float_requantization(input_quantization, output_quantization)
if sc_correction == 0:
assert zp_correction == 0, \
f"Expect zero for zp correction when scale correction is zero, but got {zp_correction}"
return 0, 0, 0
# Find value range of intermediate results (S * quantized_input)
# and (S * quantized_input + Z), as they would be if the floating-point
# correction factors are used directly
input_qmin: float = input_quantization.scale * input_quantization.min_val + input_quantization.zero_point
input_qmax: float = input_quantization.scale * input_quantization.max_val + input_quantization.zero_point
intermediate_qmin = sc_correction * input_qmin + min(zp_correction, 0)
intermediate_qmax = sc_correction * input_qmax + max(zp_correction, 0)
intermediate_abs_max_value = max(abs(intermediate_qmin), abs(intermediate_qmax))
# Value range of the integer type that is used for calculations.
# We pretend it's a symmetric range [-2^(b-1) ... 2^(b-1)] to simplify the math.
int_max = 2**(bits-1)
# Find scale factor for quantizing the correction factors. Choose a factor
# that scales up intermediate_abs_max_value to fit the integer value range.
assert intermediate_abs_max_value <= int_max, \
"Potential overflow was detected in requantization arithmetic"
base_shift_scale = int_max / intermediate_abs_max_value
shift_bits = calculate_normalization_shift(base_shift_scale)
# Ensure that the output has at least as many significant bits as the input
output_bits = min(output_quantization.bits, _symmetric_quant_bits(output_quantization.scale,
output_quantization.min_val,
output_quantization.max_val))
min_preserved_bits = min(16, output_bits)
shift_bits = min(shift_bits, bits - min_preserved_bits)
# Ensure that the scale correction is representable in sc_correction_bits
max_scale_bits = (1 << (sc_correction_bits - 1)) - 1 # Maximum representable value
estimated_sc_correction = sc_correction * (2**shift_bits)
shift_adjustment = max(0, math.ceil(math.log2(estimated_sc_correction / max_scale_bits)))
shift_bits -= shift_adjustment
assert shift_bits >= 0, f"Cannot represent scale correction with {sc_correction_bits} bits"
shift_scale = 2**shift_bits
# Quantize sc_correction and zp_correction with the chosen scale factor
int_sc_correction = round(sc_correction * shift_scale)
int_zp_correction = round(zp_correction * shift_scale)
# Zero scale correction factor means that the output would be constant zero,
# all information would be lost by requantization
assert int_sc_correction > 0, "Error in quantization"
# Eliminate the common factor of scale, zp and shift.
return _rescale_by_common_factor_of_two(int_sc_correction, int_zp_correction, shift_bits)
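# Example (illustrative): the same halving as in the float_requantization
# example reduces to integer factors S=1, Z=-64, shift=1, i.e.
# q_out = (q_in - 64) >> 1.
# >>> requantization(Quantization(63.75, -64, 8, -1.0, 3.0),
# ...                Quantization(31.875, -64, 8, -1.0, 3.0))
# (1, -64, 1)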
def requantization_tflite(input_quantization: Quantization, output_quantization: Quantization) -> Tuple[int, int, int]:
"""
Calculate correction factors to do TFLite requantization.
It returns S, Z, and shift such that data can be requantized by the calculation:
quantized_output = ((S * quantized_input) >> shift) + Z
The shift should use rounding to nearest, with any tie-breaking method.
The product (S * quantized_input) is assumed not to overflow. It is
designed for a datapath that calculates this product in 64-bit precision.
:param input_quantization: Quantization of input data. The input data's zero point must be 0.
:param output_quantization: Quantization of output data
:return: Requantization scale correction, zero point correction, and right shift
"""
assert output_quantization.bits <= input_quantization.bits
assert input_quantization.zero_point == 0
if input_quantization == output_quantization:
# Requantization is not needed
return 1, 0, 0
# Calculate correction factors in floating point
sc_correction, zp_correction = float_requantization(input_quantization, output_quantization)
assert zp_correction == output_quantization.zero_point
# Quantize sc_correction to 8 bits
from afe.ir.quantization_conv import decompose_power_of_2
exponent, frac_sc_correction = decompose_power_of_2(np.array(sc_correction), RoundType.UPWARD)
shift = int(-exponent.item() + 8)
int_sc_correction = round(frac_sc_correction.item() * (2**8))
# Shift value should be strictly smaller than the number of bits in (S * quantized_input)
assert 0 <= shift < input_quantization.bits + 8, "Requantization scale factor is out of range"
# Eliminate the common factor of scale and shift
int_sc_correction, _, shift = _rescale_by_common_factor_of_two(int_sc_correction, 0, shift)
return int_sc_correction, output_quantization.zero_point, shift
###################
# Runtime utilities
###################
def is_quantized(data: np.ndarray) -> bool:
return data.dtype in [np.int8, np.int16, np.int32]
def dequantize_tensor(data: Union[List[np.ndarray], Tuple[np.ndarray, ...], np.ndarray],
scales: List[float], zps: List[int]) \
-> Union[List[np.ndarray], Tuple[np.ndarray, ...], np.ndarray]:
"""
    Dequantize a tensor. A tensor can be a List[np.ndarray], a Tuple[np.ndarray, ...], or a np.ndarray.
"""
scales = _override_zero_scale(scales)
if isinstance(data, (tuple, list)):
dequantized_data = []
for _data, _scale, _zp in zip(data, scales, zps):
if is_quantized(_data):
dequantized_data.append(dequantize(_data, 1. / _scale, _zp))
else:
dequantized_data.append(_data)
# Cast to tuple
if isinstance(data, tuple):
dequantized_data = tuple(dequantized_data)
elif is_quantized(data):
dequantized_data = dequantize(data, 1. / scales[0], zps[0])
else:
return data
# Make sure it returns a correct type
assert type(dequantized_data) == type(data)
return dequantized_data
def quantize_tensor(data: Union[Tuple[np.ndarray, ...], np.ndarray],
scales: List[Union[float, List[float]]],
zps: List[Union[int, List[int]]],
layer_bits: List[Union[int, List[int]]],
) -> Union[Tuple[np.ndarray, ...], np.ndarray]:
"""
    Quantize a tensor. A tensor can be a Tuple[np.ndarray, ...] or a np.ndarray.
"""
if isinstance(data, tuple):
quantized_data = []
for i, _data in enumerate(data):
if not is_quantized(_data):
quantized_data.append(linear_quantize(_data, scales[i], zps[i], layer_bits[i]))
else:
quantized_data.append(_data)
data = tuple(quantized_data)
elif not is_quantized(data):
data = linear_quantize(data, scales[0], zps[0], layer_bits[0])
return data
#######################
# Operator Quantization
#######################
def quantize_alpha(alpha: np.ndarray, bits: int = 8) -> Tuple[np.ndarray, int]:
"""
Quantize the alpha for PreluOp
:param alpha: Alpha
:param bits: Number of bits used for quantization
:return: Quantized alpha, shift value
"""
alpha_scale = compute_weight_scale(alpha, bits)
right_shift = calculate_normalization_shift(alpha_scale)
# Shift factor needs to be limited to 31. For alpha values which correspond to
# larger shift (range very close to zero), alpha values are quantized to zero values,
# which effectively results in PReLU operator becoming a regular ReLU operator.
right_shift = min(right_shift, 31)
quant_alpha = np.around(alpha * np.power(2.0, right_shift)).astype(np.int8)
return quant_alpha, right_shift
def quantize_add_subtract(is_subtract: bool,
input_scales: List[float],
input_zps: List[int],
scale: float,
zero_point: int,
layer_bits: int,
in1_scale_const: int = 1, in2_scale_const: int = 1) -> Tuple[List[int], int, int]:
"""
    Quantize the add/subtract operator.
    :param is_subtract: If True, the function is used to quantize the subtract
        operator, otherwise the add operator.
    :param input_scales: Scales of the input nodes.
    :param input_zps: Zero points of the input nodes.
    :param scale: Scale of the current node.
    :param zero_point: Zero point of the current node.
    :param layer_bits: Number of bits used for quantization.
    :param in1_scale_const: Const to be folded in 1st input scale.
    :param in2_scale_const: Const to be folded in 2nd input scale.
    :return: Tuple of ([lhs_scale, rhs_scale], zero point correction, right shift).
"""
# Only scalar constants are supported at this point
assert np.isscalar(in1_scale_const)
assert np.isscalar(in2_scale_const)
in1_zp, in2_zp = input_zps
layer_zp = zero_point
in1_scale, in2_scale = input_scales
layer_scale = _override_zero_scale(scale)
in1_out_scale = layer_scale / in1_scale * in1_scale_const if in1_scale != 0. else layer_scale / in1_scale_const
in2_out_scale = layer_scale / in2_scale * in2_scale_const if in2_scale != 0. else layer_scale / in2_scale_const
# Normally, with non-zero inputs, select the max to normalize
# However, scale=1 and zp=0 are hard coded for zero tensors
# If one input is zero, select non-zero input for normalization
in_out_scale = max(in1_out_scale, in2_out_scale)
if in1_scale == 0. and in1_zp == 0:
in_out_scale = in2_out_scale
elif in2_scale == 0. and in2_zp == 0:
in_out_scale = in1_out_scale
# Compute the number of bit shifts to normalize to [0.5, 1)
layer_sh = calculate_normalization_shift(abs(in_out_scale)) + 1
in_out_scale = in_out_scale / np.power(2.0, layer_sh)
in1_out_scale = in1_out_scale / np.power(2.0, layer_sh)
in2_out_scale = in2_out_scale / np.power(2.0, layer_sh)
bound = get_bound(layer_bits)
if round_op(in_out_scale * bound) >= bound:
layer_sh += 1
in1_out_scale /= 2
in2_out_scale /= 2
in1_scale = min(in1_out_scale * bound, bound - 1)
in2_scale = min(in2_out_scale * bound, bound - 1)
right_shift = -(layer_sh - (layer_bits - 1))
in2_zp = -in2_zp if is_subtract else in2_zp
zp_corr = layer_zp * np.power(2.0, -layer_sh) - \
(in1_zp * in1_out_scale + in2_zp * in2_out_scale)
in_scale = max(in1_scale, in2_scale)
while in_scale > bound - 1:
in1_scale /= 2
in2_scale /= 2
zp_corr /= 2
right_shift -= 1
in_scale /= 2
lhs_scale = int(round_op(in1_scale))
rhs_scale = int(round_op(in2_scale))
zp_correction = int(round_op(zp_corr * bound))
return [lhs_scale, rhs_scale], zp_correction, right_shift
def _symmetric_quant_bits(scale: float, min_val: float, max_val: float) -> int:
"""
Find the number of bits of precision that would be required for
using the given quantization.
The quantization being analyzed is q = r * scale, where min_val <= r <= max_val.
:param scale: Quantization scale
:param min_val: Minimum value in the real number range
:param max_val: Maximum value in the real number range
:return: Integer bit width required to represent quantized values
"""
assert min_val <= 0 <= max_val
qmin = round(scale * min_val)
qmax = round(scale * max_val)
return max(significant_bits_signed(qmin), significant_bits_signed(qmax))
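# Example (illustrative): scale 31.875 over [-1, 3] quantizes to [-32, 96],
# and 96 needs 8 signed bits (significant_bits_signed(96) == 8).
# >>> _symmetric_quant_bits(31.875, -1.0, 3.0)
# 8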
def quantize_multiply(lhs_quant: Quantization, rhs_quant: Quantization, output_quant: Quantization,
allow_full_output_precision: bool) \
-> Tuple[int, BaseRequantization[np.ndarray], Quantization]:
"""
Quantize the multiply operator.
:param lhs_quant: Quantization of the first input of multiply
:param rhs_quant: Quantization of the second input of multiply
:param output_quant: Quantization of the output of multiply.
It may be ignored if allow_full_output_precision is True.
:param allow_full_output_precision: Whether 32-bit output is allowed. If True, then
this function may ignore output_quant and output a 32-bit quantization. If false,
then this function will quantize according to output_quant.
:return: Tuple of intrinsic shift amount, requantization to perform, and quantization of the output.
"""
assert output_quant.bits in (8, 16)
output_quant_type = np.int8 if output_quant.bits == 8 else np.int16
# Handle zero input tensors or zero output tensor
if lhs_quant.scale == 0 or rhs_quant.scale == 0 or output_quant.scale == 0:
requant = FractionalZeroRequantization(1, output_quant.zero_point,
create_and_verify_narrowing(0, RoundType.TOEVEN, output_quant_type))
return 0, requant, output_quant
# Find properties of the quantized product, (x - zp_x)*(y - zp_y).
product_scale = lhs_quant.scale * rhs_quant.scale
product_min = min(lhs_quant.min_val * rhs_quant.max_val, lhs_quant.max_val * rhs_quant.min_val)
product_max = max(lhs_quant.min_val * rhs_quant.min_val, lhs_quant.max_val * rhs_quant.max_val)
assert product_min <= 0 <= product_max
# Ensure no arithmetic saturation. If the integer product is larger than a 31-bit integer,
# use right-shift to make it smaller. Calculation is based on 31 bits to leave one
# extra bit for adding zp_correction.
product_bits = _symmetric_quant_bits(product_scale, product_min, product_max)
intrinsic_shift = max(0, product_bits - 31)
product_scale /= (1 << intrinsic_shift)
product_quant = Quantization(scale=product_scale, zero_point=0, bits=32, min_val=product_min, max_val=product_max)
# Recalculate bits with the adjusted scale
product_bits = _symmetric_quant_bits(product_scale, product_min, product_max)
# If the product has no more than 16 bits of precision, then requantization can be done
# entirely in 32-bit precision, so use FractionalZero requantization or omit the requantization.
# Otherwise, use TFLite requantization.
if product_bits <= 16:
if allow_full_output_precision:
# Omit the requantization. Use the 32-bit intermediate result as the output.
requant = FractionalZeroRequantization(1, 0,
create_and_verify_narrowing(0, RoundType.TOEVEN, np.int32))
output_quant = product_quant
else:
            # Requantize using FractionalZeroRequantization
sc_corr, zp_corr, shift = requantization(product_quant, output_quant)
requant = FractionalZeroRequantization(sc_corr, zp_corr,
create_and_verify_narrowing(shift, RoundType.TOEVEN,
output_quant_type))
else:
# Requantize using TFLiteRequantization
sc_corr, zp_corr, shift = requantization_tflite(product_quant, output_quant)
requant = TFLiteRequantization(sc_correction=sc_corr, zp_correction=zp_corr, shift=shift,
rounding=RoundType.TOEVEN, out_dtype=output_quant_type)
return intrinsic_shift, requant, output_quant
def quantize_batch_matmul(lhs_quant: Quantization, rhs_quant: Quantization, output_quant: Quantization) \
-> Tuple[int, BaseRequantization[np.ndarray], Quantization]:
assert output_quant.bits in (8, 16)
output_quant_type = np.int8 if output_quant.bits == 8 else np.int16
# Handle zero input tensors or zero output tensor
if lhs_quant.scale == 0 or rhs_quant.scale == 0 or output_quant.scale == 0:
requant = FractionalZeroRequantization(1, output_quant.zero_point,
create_and_verify_narrowing(0, RoundType.TOEVEN, output_quant_type))
return 0, requant, output_quant
scale_bits = 15 if output_quant_type == np.int16 else 7
max_shift = 15 if output_quant_type == np.int16 else 23
product_scale = lhs_quant.scale * rhs_quant.scale
req_scale = output_quant.scale / product_scale
shift = int(-np.ceil(np.log2(req_scale)).astype(np.int32)) + scale_bits
sc_corr = int(np.round(req_scale * np.power(2.0, shift)).astype(np.int32))
intrinsic_shift = int(np.maximum(shift - max_shift, 0))
if output_quant_type == np.int8:
assert intrinsic_shift == 0, "Intrinsic shift should be 0 for int8."
shift -= intrinsic_shift
zp_corr = output_quant.zero_point << shift
requant = FractionalZeroRequantization(sc_corr, zp_corr,
create_and_verify_narrowing(shift, RoundType.TOEVEN,
output_quant_type))
return intrinsic_shift, requant, output_quant
def quantize_udf(
input_quant: Quantization, output_quant: Quantization, input_type: type, output_type: type,
func: Callable[[np.ndarray], np.ndarray], invert_scales: bool = True,
) -> np.ndarray:
"""
Create a lookup table for a user-defined function.
:param input_quant: Quantization of the input.
:param output_quant: Quantization of the output.
:param input_type: Type of LUT input.
:param output_type: Type of LUT output.
:param func: Function to be approximated by the lookup table.
:param invert_scales: If true, the input scale factors are inverted.
:return: Lookup table representing func for the quantized input and output.
It is a numpy array of int8 or int16 values.
"""
if invert_scales:
input_scale = 1.0 / _override_zero_scale(input_quant.scale)
output_scale = 1.0 / _override_zero_scale(output_quant.scale)
else:
input_scale = input_quant.scale
output_scale = output_quant.scale
lut = ComputeLookupTable(
func, input_quant=(input_scale, input_quant.zero_point),
output_quant=(output_scale, output_quant.zero_point),
input_type=input_type, output_type=output_type
)
return lut.build()
def _simplify_quantized_clip(
attrs: attributes.ClipQuantAttrs, zero_point: int,
scalar_type: ScalarType, *, asymmetry: bool = False
) -> Union[attributes.ClipQuantAttrs, attributes.ReluQuantAttrs, None]:
"""
Try to simplify a clip activation function in the quantized domain. Return the attributes for the
simplified activation function (clip, relu, or nothing).
:param attrs: Quantized attributes of clip.
:param zero_point: Zero point in the quantized domain. Determines if clip can be converted to relu.
:param scalar_type: ScalarType used to initialize ReluAttrs. Has to be integer type.
:param asymmetry: Whether asymmetric quantization is used. This is only used for error checking.
:return: Simplified activation function.
"""
clip_min = attrs.a_min
clip_max = attrs.a_max
iinfo = np.iinfo(scalar_type.numpy_type())
representable_min = iinfo.min
representable_max = iinfo.max
# Clipping to the representable range is implicitly carried out by saturating arithmetic, and supersedes
# clipping to any looser range.
if clip_min <= representable_min and clip_max >= representable_max:
return None
# Asymmetric quantization should always be simplified.
if asymmetry:
sima_logger.sima_log_info("A clip operator was not eliminated with asymmetric quantization. "
"This may indicate precision loss.")
# Clipping to the positive range can be simplified to relu.
if clip_min == zero_point and clip_max >= representable_max:
return attributes.ReluQuantAttrs(attrs.shape, zero_point)
# Can't simplify clip.
return attrs
def _simplify_quantized_relu(attrs: attributes.ReluQuantAttrs, zero_point: int, scalar_type: ScalarType) \
-> Union[attributes.ReluQuantAttrs, None]:
"""
Try to simplify a relu activation function in the quantized domain. Return the attributes for the
simplified activation function (relu or nothing).
    :param attrs: Quantized attributes of relu.
    :param zero_point: Zero point in the quantized domain. Determines if relu can be eliminated.
:param scalar_type: ScalarType used to initialize ReluAttrs. Has to be integer type.
:return: Simplified activation function.
"""
representable_min = np.iinfo(scalar_type.numpy_type()).min
# If zero point is the minimum value, then relu will be implicitly carried out by saturating arithmetic.
if zero_point == representable_min:
return None
# Can't simplify relu.
return attributes.ReluQuantAttrs(attrs.input_shape, zero_point)
def quantize_clip_attrs(attrs: attributes.ClipAttrs, scalar_type: ScalarType, quant: Quantization) -> attributes.ClipQuantAttrs:
"""
    Quantize the attributes of the clip operator.
Calculate the boundaries of the clip operator based on its quantization
parameters and data type.
Args:
attrs: Attributes of the clip operator
scalar_type: Scalar data type of the quantized clip operator
quant: Quantization parameters to apply to clip operator
Returns:
Attributes of the quantized clip operator containing boundary parameters
calculated for quantized operator.
"""
iinfo = np.iinfo(scalar_type.numpy_type())
representable_min = iinfo.min
representable_max = iinfo.max
quantized_clip_min = max(representable_min, round(attrs.a_min * quant.scale + quant.zero_point))
quantized_clip_max = min(representable_max, round(attrs.a_max * quant.scale + quant.zero_point))
return attributes.ClipQuantAttrs(quantized_clip_min, quantized_clip_max, attrs.shape, scalar_type)
def quantize_activation(attrs: Union[attributes.ClipAttrs, attributes.ReluAttrs, None], quantization: Quantization,
scalar_type: ScalarType, *,
quant_config: Optional[QuantizationConfigs] = None) \
-> Union[attributes.ClipQuantAttrs, attributes.ReluQuantAttrs, None]:
"""
Quantize a simple activation function (clip, relu, or nothing) and simplify it if possible.
No requantization is introduced to these activation functions; the input and output quantization scales are
always the same. Quantization may simplify an activation function by taking advantage of the
clipping behavior of saturating arithmetic.
:param attrs: Attributes of the activation function to quantize
:param quantization: Quantization to apply to this activation function
    :param scalar_type: Scalar data type that the activation function will be evaluated on.
        Has to be an integer type.
:param quant_config: Parameters that were used to choose 'quantization'. Used for error checking.
:return: Attributes of the quantized activation function. It may be a different type than the input.
"""
assert scalar_is_integral(scalar_type)
if attrs is None:
return None
if quantization.scale == 0:
# The input of the activation operator is constant zeros. The activations do nothing.
return None
if isinstance(attrs, attributes.ClipAttrs):
clip_attrs = quantize_clip_attrs(attrs, scalar_type, quantization)
return _simplify_quantized_clip(clip_attrs, quantization.zero_point, scalar_type,
asymmetry=quant_config and quant_config.asymmetry.get())
else:
assert isinstance(attrs, attributes.ReluAttrs)
return _simplify_quantized_relu(attrs, quantization.zero_point, scalar_type)
def requantize_activation(attrs: Union[attributes.ClipQuantAttrs, attributes.ReluQuantAttrs, None],
zero_point: int,
requantization: BaseRequantization[np.ndarray],
scalar_type: ScalarType) \
-> Union[attributes.ClipQuantAttrs, attributes.ReluQuantAttrs, None]:
"""
Requantize an activation function.
This represents transforming the expression requant(activ(x)), where the
activation is evaluated before requantization, to an equivalent expression
newactiv(requant(x)), where the new activation is evaluated after requantization.
The new activation could be simpler by taking advantage of integer saturation.
:param attrs: Activation function's attributes. This must be for a quantized activation.
:param zero_point: Original zero point of the activation function, before requantization.
Ignored if attrs is None.
:param requantization: Requantization to perform. The input type of the
requantization is assumed to be int16.
:param scalar_type: ScalarType used to initialize ReluAttrs. Has to be integer type.
:return: Transformed activation function's attributes (clip, relu, or nothing).
"""
assert scalar_is_integral(scalar_type)
if attrs is None:
return None
new_zero_point = int(ml_kernels.requantization.requantize(np.array(zero_point), requantization).item())
if isinstance(attrs, attributes.ClipQuantAttrs):
a_min = ml_kernels.requantization.requantize(np.array(attrs.a_min), requantization).item()
a_max = ml_kernels.requantization.requantize(np.array(attrs.a_max), requantization).item()
return _simplify_quantized_clip(attributes.ClipQuantAttrs(a_min, a_max, attrs.shape, scalar_type),
new_zero_point, scalar_type)
else:
assert isinstance(attrs, attributes.ReluQuantAttrs)
return _simplify_quantized_relu(attrs, new_zero_point, scalar_type)
def _convert_shift_to_multiplier(shift: int | np.ndarray) -> float:
"""
Convert a per-tensor shift to a multiplier
such that x >> shift == x * _convert_shift_to_multiplier(shift).
Raise an error if the shift is per-channel.
:param shift: Shift value
:return: Multiplier value
"""
assert isinstance(shift, int), "Only per-tensor shift is supported"
return 1.0 / (1 << shift)
def requantize_quantization(quantization: Quantization, requant: BaseRequantization[np.ndarray]) -> Quantization:
"""
Get the quantization of the result of requantizing a tensor.
This would be the quantization at the output of a Requantize node, for
the given input and requantization.
:param quantization: Quantization of input tensor
:param requant: Requantization to perform
:return: Quantization of the result of applying requant to the input tensor
"""
iinfo = np.iinfo(requant.out_dtype)
# If input is a zero tensor, output must also be a zero tensor
if quantization.scale == 0:
if isinstance(requant, (TFLiteRequantization, FractionalZeroRequantization)):
assert requant.zp_correction == 0, "Cannot handle zero point correction on a zero tensor"
return Quantization(0.0, 0, iinfo.bits, 0.0, 0.0)
# Given the old quantization q = r * scale + zero_point and requantization q2 = requant(q),
# find the new quantization q2 = r * scale2 + zero_point2.
if isinstance(requant, FractionalZeroRequantization):
pow2_scale = _convert_shift_to_multiplier(requant.narrowing.shift)
new_scale = pow2_scale * quantization.scale * requant.sc_correction
new_zp = round(pow2_scale * (quantization.zero_point * requant.sc_correction + requant.zp_correction))
elif isinstance(requant, TFLiteRequantization):
pow2_scale = _convert_shift_to_multiplier(requant.shift)
new_scale = pow2_scale * quantization.scale * requant.sc_correction
new_zp = round(pow2_scale * quantization.zero_point * requant.sc_correction) + requant.zp_correction
elif isinstance(requant, ArithFoldedRequantization):
pow2_scale = _convert_shift_to_multiplier(requant.narrowing.shift)
new_scale = pow2_scale * quantization.scale
new_zp = round(pow2_scale * quantization.zero_point)
else:
raise TypeError("Unrecognized requantization type")
# Clip (min_val, max_val) to the representable range.
# This models the effect of saturating arithmetic in the requantization operator.
min_val = max(quantization.min_val, (iinfo.min - new_zp) / new_scale)
max_val = min(quantization.max_val, (iinfo.max - new_zp) / new_scale)
assert min_val < max_val
return Quantization(new_scale, new_zp, iinfo.bits, min_val, max_val)
def quantize_prelu(layer_bits: int, alpha: Union[np.ndarray, float]) -> Tuple[int, int]:
"""
    Quantize the PRelu alphas and return the quantized alphas and right shifts.
    :param layer_bits: Number of bits used for quantization
    :param alpha: Union[np.ndarray, float]. alpha in float data type
    :return: Tuple of (quantized alpha, right shift)
"""
alpha, right_shift = compute_power_of_2_scale_and_shift(alpha, layer_bits, 8)
quantized_alpha = round_op(alpha)
# Cast to Quantized parameter data type
if quantized_alpha.size == 1:
quantized_alpha = QuantizedParam(quantized_alpha)
else:
quantized_alpha = quantized_alpha.astype(QuantizedParam)
return quantized_alpha, right_shift
def quantize_reciprocal(input_qtype: attributes.QuantResultTensorType) -> attributes.AwesomeCalibAttrs:
"""
Quantize the reciprocal part of divide
:param input_qtype: quantization for rhs argument of divide.
:return: calibration attributes AwesomeCalibAttrs which are used in ReciprocalOp UDF.
"""
assert isinstance(input_qtype.quant, Quantization)
reciprocal_min_value: float = 1.0 / input_qtype.quant.min_val
reciprocal_max_value: float = 1.0 / input_qtype.quant.max_val
reciprocal_scale: float = 127.0 / max(abs(reciprocal_min_value), abs(reciprocal_max_value))
reciprocal_zp = 0
reciprocal_quantization = Quantization(reciprocal_scale, reciprocal_zp, 8,
reciprocal_min_value, reciprocal_max_value)
output_qtype = attributes.QuantResultTensorType(input_qtype.type, reciprocal_quantization)
inputs_qtype = {InputName('data'): TensorValue(input_qtype)}
calib_attrs = attributes.AwesomeCalibAttrs(quant=TensorValue(output_qtype), input_quant=inputs_qtype)
return calib_attrs
def quantize_lrn(attrs: attributes.LRNAttrs, input_quant: Quantization, quant: Quantization) \
-> attributes.LRNQuantAttrs:
"""
Quantize LRN which is implemented based on quantized_local_response_normalization from ml_kernels repo:
out = lut(square_sum(x)) * x
    where the lut function is:
        lambda x: (bias + alpha / size * x) ** (-beta)
    :param attrs: LRN attributes.
    :param input_quant: Quantization of input data
    :param quant: Layer quantization
    :return: Quantized LRN attributes.
"""
input_scale = input_quant.scale
input_zp = input_quant.zero_point
input_min = input_quant.min_val
input_max = input_quant.max_val
output_scale = quant.scale
output_zp = quant.zero_point
# Calculate scale and zero point for sum of input squares from input calibration parameters
sq_min = 0
sq_max = max(input_min * input_min, input_max * input_max) * attrs.size
sq_scale = compute_scale(asymmetry=True, layer_bits=quant.bits, min_val=sq_min, max_val=sq_max)
sq_zp = compute_zero_point(asymmetry=True, layer_bits=quant.bits, min_val=sq_min, max_val=sq_max)
# Calculate maximal integer value for sum of input squares
max_sq = max((-128 - input_zp) * (-128 - input_zp), (127 - input_zp) * (127 - input_zp)) * attrs.size
# LUT requantization parameters
# Calculate maximum number of bits for lut_scale_q to avoid overflow
lut_scale_bits = 31 - np.ceil(np.log2(max_sq)).astype(np.int32)
lut_scale = sq_scale / (input_scale * input_scale)
lut_sh = -calculate_normalization_shift(lut_scale, rounding=RoundType.UPWARD) + lut_scale_bits
lut_scale_q = int(np.round(lut_scale * np.power(2.0, lut_sh)))
# Calculate LUT
func = lambda x: (attrs.bias + attrs.alpha / attrs.size * x) ** (-attrs.beta)
lut_output_scale = max(func(sq_min), func(sq_max)) / np.iinfo(np.int8).max
lut = ComputeLookupTable(func, input_quant=(1.0 / sq_scale, sq_zp),
output_quant=(lut_output_scale, 0), input_type=np.int8, output_type=np.int8)
lookup_table = lut.build()
# Calculate requantization parameters for output = input * LUT
layer_scale = output_scale * lut.output_scale / input_scale
layer_sh = -calculate_normalization_shift(layer_scale, rounding=RoundType.UPWARD) + 15
# Use a maximum of 15 bits for layer_scale_q to avoid overflow
layer_scale_q = int(np.round(layer_scale * np.power(2.0, layer_sh)))
quant_attrs = attributes.LRNQuantAttrs(axis=attrs.axis, size=attrs.size, shape=attrs.shape, input_zp=input_zp,
lut_scale=lut_scale_q, lut_zp_corr=int(sq_zp << lut_sh), lut_sh=int(lut_sh),
output_scale=layer_scale_q, output_zp_corr=int(output_zp << layer_sh),
output_sh=int(layer_sh), lookup_table=lookup_table)
return quant_attrs
def _fractional_zero_requantization(out_scale: float, in_scale: float, out_max: int, out_zp: int,
scale_bits: Optional[int] = None) -> FractionalZeroRequantization:
"""
Calculate fractional zero re-quantization.
:param out_scale: Output quantization scale in SiMa's convention.
:param in_scale: Input quantization scale in SiMa's convention.
:param out_max: Theoretical worst-case maximum for the output value prior to re-quantization.
:param out_zp: Output quantization zero point.
:param scale_bits: Number of bits to use. Calculate if not provided.
:return: Fractional zero re-quantization.
"""
# Calculate number of bits to be used for scale constant in the way that saturation is avoided.
scale_bits = 30 - np.floor(np.log2(out_max)).astype(np.int32) if scale_bits is None else (
scale_bits)
scale = out_scale / in_scale
sh = -calculate_normalization_shift(scale, rounding=RoundType.UPWARD) + scale_bits
# Limit shift value
sh = min(sh, 23)
scale_q = np.round(scale * np.power(2.0, sh)).astype(np.int32)
if scale_q == 0:
raise ValueError("Invalid scale. Please try using different calibration and/or "
"quantization schemes.")
zp_corr = out_zp << sh
return FractionalZeroRequantization(
int(scale_q), int(zp_corr), create_and_verify_narrowing(int(sh), RoundType.UPWARD, np.int8))
def _get_softmax_int16_requant(out_scale: float, in_scale: float, out_zp: int) ->\
Tuple[int, FractionalZeroRequantization]:
"""
Get fractional zero re-quantization for int16 Softmax.
:param out_scale: Output quantization scale in SiMa's convention.
:param in_scale: Input quantization scale in SiMa's convention.
:param out_zp: Output quantization zero point.
:return: Pre-shift value and Fractional zero re-quantization.
"""
lut_scale_bits = 15
scale = out_scale / in_scale
lut_sh = -int(np.ceil(np.log2(scale))) + lut_scale_bits
lut_scale_q = int(np.round(scale * np.power(2.0, lut_sh)))
pre_sh = max(lut_sh - lut_scale_bits, 0)
lut_sh -= pre_sh
lut_zp_corr = (out_zp << lut_sh)
requant_lut = FractionalZeroRequantization(lut_scale_q, lut_zp_corr,
Narrowing(lut_sh, RoundType.TOEVEN, np.int16))
return pre_sh, requant_lut
def quantize_softmax(attrs: attributes.SoftmaxAttrs, input_quant: Quantization, quant: Quantization,
intermediate_min_max: Dict[str, Tuple[float, float]],
enable_int16: bool) \
-> attributes.SoftmaxQuantAttrs:
"""
Quantize Softmax which is implemented based on softmax implementation from ml_kernels repo:
exp = lut_exp(x) # lut_exp(x) = exp(x)
exp_sum_rec = lut_rec(np.sum(exp)) # lut_rec(x) = 1/x
ofm = exp * exp_sum_rec
:param attrs: Softmax attributes.
:param input_quant: Quantization of input data
:param quant: Layer quantization
:param intermediate_min_max: Dict of intermediates min/max values.
    :param enable_int16: If True, use int16 quantization; otherwise use int8.
:return: Quantized Softmax attributes
"""
from ml_kernels.np_operators import compute_intermediate_lut
import ml_kernels.requantization as requant
axis = attrs.axis
output_scale = quant.scale
output_zp = quant.zero_point
dtype = np.int16 if enable_int16 else np.int8
# Create exp(x) lookup table.
exp_in_min = input_quant.min_val - input_quant.max_val
exp_in_max = 0.0
    # Correct exp_in_min value so the calculated exp_zp remains in a valid range. The range is
    # chosen so that max_sum_q stays within 23 bits (32 bits minus 9 bits for the scale).
max_allowed_zp = int(np.ceil(127 - np.sqrt(2 ** 23)))
if enable_int16:
exp_in_min = exp_in_min
symmetric = True
else:
max_allowed_min = _get_max_allowed_min(exp_in_max, max_allowed_zp, lambda x: float(np.exp(x)),
lambda x: float(np.log(x)))
exp_in_min = min(exp_in_min, max_allowed_min)
symmetric = False
lut_exp, exp_zp, exp_scale = compute_intermediate_lut(lambda x: float(np.exp(x)), exp_in_min, exp_in_max,
dtype=dtype, symmetric=symmetric)
# Calculate scale and zero point for rec input.
n = attrs.input_shape[axis]
if enable_int16:
sum_exp = intermediate_min_max['sum_exp']
sum_exp_min, sum_exp_max = min(sum_exp), min(4 * max(sum_exp), float(n))
rec_in_min, rec_in_max = sum_exp_max / 255.0, sum_exp_max
else:
sum_exp_min, sum_exp_max = intermediate_min_max['sum_exp']
rec_in_min, rec_in_max = max(sum_exp_max / 255.0, sum_exp_min), sum_exp_max
    # Correct rec_in_min value so the calculated rec_zp remains in a valid range. The range is
    # chosen so that max_sum_q stays within 23 bits (32 bits minus 9 bits for the scale).
max_allowed_rec_in_min = _get_max_allowed_min(rec_in_max, max_allowed_zp, lambda x: float(np.reciprocal(x)),
lambda x: float(np.reciprocal(x)))
rec_in_min = min(rec_in_min, max_allowed_rec_in_min)
rec_in_scale = compute_scale(asymmetry=True, layer_bits=quant.bits, min_val=rec_in_min,
max_val=rec_in_max)
# Allow zp to be out of range.
rec_in_zp = -(int(np.round(rec_in_min * rec_in_scale)) + np.iinfo(dtype).max + 1)
# Calculate reciprocal(x) lookup table.
lut_rec, rec_zp, rec_scale = compute_intermediate_lut(lambda x: float(np.reciprocal(x)), rec_in_min, rec_in_max,
func_defined_at_zero=False, dtype=dtype)
# Calculate lut re-quant params.
    # LUT input requantization
if enable_int16:
lut_pre_sh, requant_lut = _get_softmax_int16_requant(out_scale=rec_in_scale, in_scale=exp_scale,
out_zp=rec_in_zp)
else:
max_sum_q = (127 - exp_zp) * n
requant_lut = _fractional_zero_requantization(out_scale=rec_in_scale, in_scale=exp_scale,
out_max=max_sum_q, out_zp=rec_in_zp)
lut_pre_sh = None
# Update zp_correction to match the kernel implementation.
zp_correction = requant_lut.zp_correction - requant_lut.sc_correction * exp_zp * n
requant_lut = requant.FractionalZeroRequantization(requant_lut.sc_correction,
zp_correction,
requant_lut.narrowing)
# Calculate output re-quant params.
if enable_int16:
out_pre_sh, requant_output = _get_softmax_int16_requant(out_scale=output_scale, in_scale=exp_scale * rec_scale,
out_zp=output_zp)
else:
max_out_q = (127 - exp_zp) * (127 - rec_zp)
requant_output = _fractional_zero_requantization(out_scale=output_scale,
in_scale=exp_scale * rec_scale,
out_max=max_out_q, out_zp=output_zp)
out_pre_sh = None
quant_attrs = attributes.SoftmaxQuantAttrs(axis=axis,
input_shape=attrs.input_shape,
exp_zp=exp_zp,
rec_zp=rec_zp,
requant_lut=requant_lut,
requant_output=requant_output,
lookup_table_exp=lut_exp,
lookup_table_rec=lut_rec,
enable_int16=enable_int16,
lut_input_pre_shift=lut_pre_sh,
output_pre_shift=out_pre_sh)
return quant_attrs
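# Quick numeric check (hypothetical parameter values) of the zero-point folding performed in
# quantize_softmax above: summing the raw quantized exp outputs and folding
# n * exp_zp * sc_correction into zp_correction matches subtracting exp_zp from every addend.
def _example_softmax_zp_folding() -> None:
    n, exp_zp = 4, -100
    sc_corr, zp_corr, sh = 181, 1 << 11, 12
    q_exp = np.array([-90, -80, -100, -95], dtype=np.int32)  # quantized exp outputs
    ref = (np.sum(q_exp - exp_zp) * sc_corr + zp_corr) >> sh
    folded = (np.sum(q_exp) * sc_corr + (zp_corr - sc_corr * exp_zp * n)) >> sh
    assert ref == folded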
def _get_max_allowed_min(in_max: float, zp_limit: int, fun: Callable[[float], float],
inv_fun: Callable[[float], float]) -> float:
"""
Calculate the largest allowed minimum value for which the zero point produced by
compute_intermediate_lut during LUT creation stays within the required limit.
"""
assert zp_limit < -128
lut_max = fun(in_max)
max_allowed_min = lut_max / (1 - 255.0 / (zp_limit + 128))
if inv_fun(max_allowed_min) > in_max:
max_allowed_min = lut_max * (1 - 255.0 / (zp_limit + 128))
return inv_fun(max_allowed_min)
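# Example of the bound above, mirroring the int8 softmax usage: for f(x) = exp(x) with
# inputs shifted so that the maximum is 0, the returned minimum keeps the LUT zero point
# within the limit derived in quantize_softmax.
def _example_max_allowed_min() -> None:
    zp_limit = int(np.ceil(127 - np.sqrt(2 ** 23)))  # the same limit used by quantize_softmax
    m = _get_max_allowed_min(0.0, zp_limit, lambda x: float(np.exp(x)),
                             lambda x: float(np.log(x)))
    assert m < 0.0  # softmax inputs lie in (-inf, 0], so the bound is negative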
def _compute_scale_and_zp(min_val: float, max_val: float, bits: int = 8) -> Tuple[float, int]:
"""
Helper function for calculating scale and zero point; it is used in quantization of the
Layer and Instance Normalization operators.
The zero point is allowed to fall outside the [-128, 127] range, as that yields better
accuracy; the quantized value is later clipped to the [-128, 127] range.
"""
scale = compute_scale(asymmetry=True, layer_bits=bits, min_val=min_val, max_val=max_val)
# Allow zp to be out of -128, 127 range.
zp = -(np.round(min_val * scale) + 128).astype(np.int32)
return scale, zp
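# Small usage sketch (hypothetical range): in the convention used here,
# q = round(x * scale) + zp, so the returned zero point maps min_val exactly to -128
# even when zp itself falls outside [-128, 127].
def _example_compute_scale_and_zp() -> None:
    scale, zp = _compute_scale_and_zp(min_val=-0.5, max_val=7.5)
    assert int(np.round(-0.5 * scale)) + zp == -128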
def quantize_layer_norm(attrs: attributes.LayerNormAttrs,
input_quant: Quantization,
quant: Quantization,
intermediate_min_max: dict[str, tuple[float, float]]) \
-> attributes.LayerNormQuantAttrs:
"""
Quantize LayerNorm, which is implemented based on the layer norm implementation in the ml_kernels repo:
LayerNorm(input, axis, epsilon) = (input - m) / Sqrt(var + epsilon), where
m = ReduceMean(input, axis, keepdims=True),
var = ReduceMean((input - m) ** 2, axis, keepdims=True).
Use a LUT for the reciprocal of the sqrt function.
:param attrs: LayerNormAttrs attributes.
:param input_quant: Quantization of input data.
:param quant: Layer quantization.
:param intermediate_min_max: Dict of intermediates' min/max values.
:return: Quantized LayerNormAttrs attributes.
"""
from ml_kernels.np_operators import compute_intermediate_lut
axis = attrs.axis
input_scale = input_quant.scale
output_scale = quant.scale
output_zp = quant.zero_point
# Calculate mean re-quant params.
n = attrs.input_shape[axis]
mean_max = n * 127
requant_mean = _fractional_zero_requantization(out_scale=1, in_scale=n, out_max=mean_max,
out_zp=0)
# Create rsqrt lookup table.
lut_fun: Callable[[float], float] = lambda x: 1 / float(np.sqrt(x + attrs.epsilon))
var_min, var_max = intermediate_min_max['var']
lut_in_min, lut_in_max = max(var_max / 128, var_min), var_max
# Correct the min value so that the calculated rsqrt_zp remains in a valid range. The range
# is chosen so that max_outq stays within 23 bits (32 bits minus 9 scale bits).
max_allowed_exp_zp = int(np.ceil(127 - (2**23)/255))
inv_lut_fun: Callable[[float], float] = lambda x: (1 / x) ** 2 - attrs.epsilon
max_allowed_min = _get_max_allowed_min(lut_in_max, max_allowed_exp_zp, lut_fun, inv_lut_fun)
lut_in_min = min(lut_in_min, max_allowed_min)
lut_rsqrt, rsqrt_zp, rsqrt_scale = compute_intermediate_lut(lut_fun, lut_in_min, lut_in_max)
# Calculate lut input re-quant params.
lut_in_scale, lut_in_zp = _compute_scale_and_zp(lut_in_min, lut_in_max, quant.bits)
max_lut_in = 127 * n
requant_lut_in = _fractional_zero_requantization(out_scale=lut_in_scale,
in_scale=input_scale * input_scale * n,
out_max=max_lut_in, out_zp=lut_in_zp,
scale_bits=7)
# Calculate output re-quant params.
max_outq = 255 * (127 - rsqrt_zp)
requant_output = _fractional_zero_requantization(out_scale=output_scale,
in_scale=input_scale * rsqrt_scale,
out_max=max_outq, out_zp=output_zp)
return attributes.LayerNormQuantAttrs(axis=attrs.axis,
input_shape=attrs.input_shape,
zp_rsqrt=int(rsqrt_zp),
lookup_table_rsqrt=lut_rsqrt,
requant_mean=requant_mean,
requant_lut_input=requant_lut_in,
requant_output=requant_output)
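# Sanity check of the zero-point limit used above: with rsqrt_zp bounded below by
# ceil(127 - 2**23 / 255), max_outq = 255 * (127 - rsqrt_zp) stays within 23 bits.
def _example_layer_norm_zp_limit() -> None:
    zp_limit = int(np.ceil(127 - (2 ** 23) / 255))
    assert 255 * (127 - zp_limit) <= 2 ** 23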
def quantize_instance_norm(attrs: attributes.InstanceNormAttrs,
input_quant: Quantization,
mean_quant: Quantization,
variance_quant: Quantization,
quant: Quantization):
"""
Quantize Instance Normalization operator: (input - mean) / sqrt(variance + epsilon).
Args:
attrs: Instance Normalization attributes.
input_quant: Quantization of the input data.
mean_quant: Quantization of the mean input data.
variance_quant: Quantization of the variance input data.
quant: Layer quantization.
Returns:
Quantized Instance Normalization attributes.
"""
from ml_kernels.np_operators import compute_intermediate_lut
input_scale = max(input_quant.scale, mean_quant.scale)
output_scale = quant.scale
output_zp = quant.zero_point
# Create rsqrt lookup table.
lut_fun: Callable[[float], float] = lambda x: 1 / float(np.sqrt(x + attrs.epsilon))
var_min, var_max = variance_quant.min_val, variance_quant.max_val
lut_in_min, lut_in_max = max(var_max / 128, var_min), var_max
# Correct the min value so that the calculated rsqrt_zp remains in a valid range. The range
# is chosen so that max_outq stays within 23 bits (32 bits minus 9 scale bits).
max_allowed_exp_zp = int(np.ceil(127 - (2 ** 23) / 255))
inv_lut_fun: Callable[[float], float] = lambda x: (1 / x) ** 2 - attrs.epsilon
max_allowed_min = _get_max_allowed_min(lut_in_max, max_allowed_exp_zp, lut_fun, inv_lut_fun)
lut_in_min = min(lut_in_min, max_allowed_min)
lut_rsqrt, rsqrt_zp, rsqrt_scale = compute_intermediate_lut(lut_fun, lut_in_min, lut_in_max)
# Calculate output re-quant params.
max_outq = 255 * (127 - rsqrt_zp)
requant_output = _fractional_zero_requantization(out_scale=output_scale,
in_scale=input_scale * rsqrt_scale,
out_max=max_outq, out_zp=output_zp)
return attributes.InstanceNormQuantAttrs(attrs=attrs, lut_rsqrt=lut_rsqrt, zp_rsqrt=rsqrt_zp,
requant_out=requant_output)
def quantize_rms_norm(attrs: attributes.RMSNormAttrs, input_quant: Quantization, quant: Quantization,
intermediate_min_max: Dict[str, Tuple[float, float]],
enable_lut_int16: bool) -> attributes.RMSNormQuantAttrs:
"""
Quantize RMS Normalization, which is implemented based on the rms norm implementation in the ml_kernels repo:
RMSNorm(x, axis, epsilon) = x / Sqrt(ReduceMean(x ** 2, axis, keepdims=True) + epsilon)
Use a LUT for the reciprocal of the sqrt function.
:param attrs: RMSNorm attributes.
:param input_quant: Quantization of input data.
:param quant: Layer quantization.
:param intermediate_min_max: Dict of intermediates' min/max values.
:param enable_lut_int16: If True, quantize the LUT to int16; otherwise to int8.
:return: Quantized RMSNorm attributes.
"""
from ml_kernels.np_operators import compute_intermediate_lut
# Define number of bits and data types
lut_bits = 16 if enable_lut_int16 else 8
lut_dtype = np.int16 if enable_lut_int16 else np.int8
lut_out_bits = 16 if enable_lut_int16 else 8
max_lut_sh = 31 - lut_bits
max_out_sh = 31 - lut_out_bits
# Input and output quantization parameters
input_scale = input_quant.scale
out_scale = quant.scale
y_zero_point = quant.zero_point
# The reciprocal of the sqrt part of the RMS norm function is approximated using a LUT
# Calculate LUT input range
sq_min, sq_max = intermediate_min_max['reduce_mean']
input_min = input_quant.min_val
input_max = input_quant.max_val
x_sq_max = np.maximum(input_min * input_min, input_max * input_max)
# The constants here were found experimentally while searching for the best range for the LUT
# so that loss of information is minimized
n = attrs.input_shape[-1]
# k multiplies sq_max so that the maximum input is moved towards the middle of the LUT, where
# accuracy is better: 1/sqrt(x) is almost flat near the maximum, so accuracy is poor there.
# For int8, the small number of bits means this does not improve accuracy, so k is set to 1.
k = 1 if lut_bits == 8 else 4
sq_max = np.minimum(k * sq_max, x_sq_max * n)
# sq_max / 255 is used for lut_in_min so that the minimum is not close to 0, where 1/sqrt(min)
# would be huge. If sq_min is bigger than sq_max / 255, then sq_min is used, since it is
# further from 0. However, a bigger sq_min means a bigger lut_zp, which for int16 could fall
# outside the int16 range; that is why sq_max / 255 is always used for int16.
if lut_bits == 8:
lut_in_min, lut_in_max = max(sq_max / 255, sq_min), sq_max
else:
lut_in_min, lut_in_max = sq_max / 255, sq_max
# Calculate LUT quantization parameters
lut_in_scale = compute_scale(asymmetry=True, layer_bits=lut_bits, min_val=lut_in_min, max_val=lut_in_max)
lut_in_zp = -(np.round(lut_in_min * lut_in_scale) + (2 ** (lut_bits - 1))).astype(np.int32)
# Calculate LUT re-quantization parameters
lut_in_scale_bits = 7 if lut_bits == 8 else 15
lut_in_req_scale = lut_in_scale / (input_scale * input_scale * n)
sh = np.ceil(np.log2(lut_in_req_scale)).astype(np.int32)
lut_in_sh = lut_in_scale_bits - sh
lut_in_scale_q = int(np.round(lut_in_req_scale * np.power(2.0, lut_in_sh)).astype(np.int32))
lut_in_pre_sh = np.maximum(lut_in_sh - max_lut_sh, 0)
lut_in_sh -= lut_in_pre_sh
req_lut_in = TFLiteRequantization(lut_in_scale_q, int(lut_in_sh), RoundType.TOEVEN, int(lut_in_zp), lut_dtype)
# Calculate LUT
lut_fun = lambda x: 1 / np.sqrt(x + attrs.epsilon)
lut_rsqrt, rsqrt_zp, rsqrt_scale = compute_intermediate_lut(lut_fun, lut_in_min, lut_in_max, dtype=lut_dtype)
# Calculate output re-quantization parameters
output_scale = out_scale / (input_scale * rsqrt_scale)
sh = np.ceil(np.log2(output_scale)).astype(np.int32)
out_scale_bits = 10 if lut_bits == 8 else 15
out_sh = out_scale_bits - sh
out_scale_q = int(np.round(output_scale * np.power(2.0, out_sh)).astype(np.int32))
out_pre_sh = np.maximum(out_sh - max_out_sh, 0)
out_sh -= out_pre_sh
req_output = TFLiteRequantization(out_scale_q, int(out_sh), RoundType.TOEVEN, int(y_zero_point), np.int8)
return attributes.RMSNormQuantAttrs(input_shape=attrs.input_shape, zp_ifm=input_quant.zero_point,
lookup_table_rsqrt=lut_rsqrt, zp_rsqrt=rsqrt_zp, requant_lut_input=req_lut_in,
requant_output=req_output, lut_input_pre_shift=int(lut_in_pre_sh),
output_pre_shift=int(out_pre_sh), enable_lut_int16=enable_lut_int16)
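# Illustrative decomposition (hypothetical scale) matching the LUT-input requantization
# above for the int16 case: the float requant scale becomes an integer multiplier, a right
# shift capped at max_lut_sh = 15, and a pre-shift applied to the accumulator first.
def _example_rms_norm_requant_decomposition() -> None:
    req_scale, scale_bits, max_sh = 3.7e-6, 15, 15
    sh = scale_bits - int(np.ceil(np.log2(req_scale)))
    scale_q = int(np.round(req_scale * 2.0 ** sh))
    pre_sh = max(sh - max_sh, 0)
    sh -= pre_sh
    x = 1 << 24  # a large hypothetical accumulator value
    approx = ((x >> pre_sh) * scale_q) >> sh
    assert abs(approx - x * req_scale) / (x * req_scale) < 1e-2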
def quantization_data_value_to_output_list(quantization: DataValue[Quantization]) \
-> Tuple[List[float], List[int], List[int], List[int], List[int]]:
"""
Convert a DataValue of Quantization object(s) to lists of quantization-related values.
This is used for interfacing to code that stores quantization information in five separate lists.
:param quantization: DataValue of Quantization object(s) to convert to a tuple of quantization parameter lists.
:return: Lists of scales, zero points, bits, minimum and maximum values.
"""
scales = []
zero_points = []
bitss = []
min_vals = []
max_vals = []
# Flatten the contents of quantizations into lists
def extract_from(q: DataValue[Quantization]):
if isinstance(q, TensorValue):
scales.append(q.value.scale)
zero_points.append(q.value.zero_point)
bitss.append(q.value.bits)
min_vals.append(q.value.min_val)
max_vals.append(q.value.max_val)
elif isinstance(q, TupleValue):
for element in q.elements:
extract_from(element)
else:
raise TypeError("Invalid type for DataValue")
extract_from(quantization)
return scales, zero_points, bitss, min_vals, max_vals
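# Usage sketch only. This assumes Quantization can be constructed by keyword from the five
# fields read above (scale, zero_point, bits, min_val, max_val) and that TupleValue takes a
# list of elements; adjust to the actual constructors if they differ.
def _example_flatten_quantization() -> None:
    q = Quantization(scale=0.5, zero_point=3, bits=8, min_val=-1.0, max_val=1.0)
    tree = TupleValue([TensorValue(q), TensorValue(q)])
    scales, zps, bits, min_vals, max_vals = quantization_data_value_to_output_list(tree)
    assert scales == [0.5, 0.5] and zps == [3, 3]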
def _fix_narrowing(narrowing: Narrowing) -> Narrowing:
"""
Change the data type of the shift array, if it is present, to uint8.
"""
shift = narrowing.shift
if isinstance(shift, np.ndarray):
return dataclasses.replace(narrowing, shift=shift.astype(np.uint8))
# else
return narrowing
def fix_requantization(requantization: BaseRequantization[np.ndarray]) -> BaseRequantization[np.ndarray]:
"""
Change the data type of the requantization's shift array, if it is present, to uint8.
"""
if isinstance(requantization, ml_kernels.requantization.TFLiteRequantization):
shift = requantization.shift
if isinstance(shift, np.ndarray):
return dataclasses.replace(requantization, shift=shift.astype(np.uint8))
# else
return requantization
elif isinstance(requantization, (ml_kernels.requantization.ArithFoldedRequantization,
                                 ml_kernels.requantization.FractionalZeroRequantization)):
    # Both types carry the shift inside a Narrowing; fix it there.
    return dataclasses.replace(requantization, narrowing=_fix_narrowing(requantization.narrowing))
raise TypeError("Unrecognized requantization type")
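# Example (hypothetical per-channel values): a TFLiteRequantization whose shift array
# defaults to a wider integer dtype is normalized to uint8 by fix_requantization.
def _example_fix_requantization() -> None:
    req = TFLiteRequantization(sc_correction=np.array([12000, 9000], dtype=np.int32),
                               shift=np.array([14, 13]), rounding=RoundType.TOEVEN,
                               zp_correction=0, out_dtype=np.int8)
    fixed = fix_requantization(req)
    assert isinstance(fixed, TFLiteRequantization) and fixed.shift.dtype == np.uint8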
def create_requantization_from_cast(cast: RequantCast) -> BaseRequantization[np.ndarray]:
"""
Get the Requantization that implements the given cast.
:param cast: Cast to perform
:return: Requantization
"""
rq_input = cast.get_input_quantization()
rq_output = cast.get_output_quantization()
out_dtype = np.int8 if rq_output.bits == 8 else np.int16
if cast.requant_method == RequantMethod.arith_folded:
shift = power_of_2_requantization(rq_input, rq_output)
if shift >= 0:
requant = narrowing_requantization(shift, RoundType.TOEVEN, out_dtype)
else:
# Scale factor is greater than 1, equivalent to a left shift. Use a multiplication.
assert shift >= -8, "Scale factor is too large for requantization"
requant = TFLiteRequantization(sc_correction=1 << -shift, shift=0, rounding=RoundType.TRUNC,
zp_correction=0, out_dtype=out_dtype)
else:
assert cast.requant_method in (RequantMethod.fractional_zero, RequantMethod.scaled_fz)
sc_corr, zp_corr, shift = requantization(rq_input, rq_output, 32)
if cast.requant_method == RequantMethod.scaled_fz:
# The input and output quantization should have been chosen to ensure zp_corr == 0.
# Allow error up to +/- (0.5 * 2**shift), which would be +/- 0.5 in the output.
zp_corr_bound = 1 << max(shift - 1, 0)
assert abs(zp_corr) <= zp_corr_bound, \
"Scale-only requantization is not suitable for producing the wanted quantization"
zp_corr = 0
requant = FractionalZeroRequantization(sc_corr, zp_corr, Narrowing(shift, RoundType.TOEVEN, out_dtype))
return requant
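# Numeric justification (hypothetical values) for zeroing zp_corr in the scaled_fz branch
# above: a correction of at most 2**(shift - 1) moves the requantized output by at most 0.5.
def _example_scaled_fz_zp_bound() -> None:
    sc_corr, shift = 19000, 15
    zp_corr = 1 << (shift - 1)  # the largest correction the assertion above tolerates
    for v in (-300, 0, 7, 1234):
        with_corr = (v * sc_corr + zp_corr) / 2.0 ** shift
        without = (v * sc_corr) / 2.0 ** shift
        assert abs(with_corr - without) <= 0.5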