#########################################################
# Copyright (C) 2021 SiMa Technologies, Inc.
#
# This material is SiMa proprietary and confidential.
#
# This material may not be copied or distributed without
# the express prior written permission of SiMa.
#
# All rights reserved.
#########################################################
# Code owner: Joey Chou
#########################################################
from dataclasses import dataclass, field, replace
import numpy as np
import operator
from typing import List, Dict, Any, Tuple, Optional, Union, TypeVar, Generic, Sequence
from ml_kernels.c_function_call_helpers import OperatorFunction
from ml_kernels.math_helpers import RoundType
from ml_kernels.requantization import BaseRequantization, TFLiteRequantization
from afe._tvm._defines import TVMGraphModule
from afe._tvm._runtime import generate_graph_module
from afe.apis.defines import (
ChromaSampling, ColorConversion, ColorSpaceStandard, ResizeDepositLocation, ResizeMethod
)
from afe.backends import Backend
from afe.ir.custom_operation.utils import parse_custom_op_attrs_to_dict
from afe.ir.defines import (
AwesomeDilation3D, AwesomePad3D, AwesomeStrides3D, NoneType, InputName, AwesomePad,
AwesomeStrides, AwesomeDilation, AwesomePoolSize, DataValue, TensorValue, Quantization,
reduce_data_value, map_data_value, TupleValue, RequantMethod
)
from afe.ir.node_observer import NodeObserver
from afe.ir.tensor_type import (
TensorType, ScalarType, NodeType, set_tensor_type_batch_size, set_node_type_batch_size, scalar_is_floating
)
from afe.ir.utils import set_shape_batch_size, is_mla_supported_einsum_equation
DEFAULT_PER_CHANNEL = False
class ObservedDistribution:
"""
A value distribution that was observed during calibration. This value
distribution can be used to decide how to quantize a tensor.
"""
_observer: NodeObserver
def __init__(self, observer: NodeObserver):
"""
Initialize an instance.
:param observer: Observer holding statistics about a data distribution.
The observer should not be modified while this class is in use.
"""
self._observer = observer
def calculate_quantization(self, qrange: Tuple[int, int]) -> DataValue[Quantization]:
"""
Choose a quantization to use for representing the observed value distribution
using the given integer range.
:param qrange: Integer range to quantize for. The range must be representable by
an 8-bit or 16-bit signed integer.
:return: Selected quantizations.
"""
return self._observer.calculate_quantization(qrange)
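# Illustrative usage sketch (assumed typical arguments, not taken from this module):
#     int8_quant = observed.calculate_quantization((-128, 127))
#     int16_quant = observed.calculate_quantization((-32768, 32767))
# where `observed` is a hypothetical ObservedDistribution instance.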
def get_min_max(self) -> Tuple[float, float]:
"""
Get range (min and max) of observed values.
Min-max range does not need to include zero.
:return: Tuple of min and max values.
"""
return self._observer.min_max()
def get_mean(self):
return self._observer.get_mean()
@dataclass(frozen=True)
class QuantResultTensorType:
"""
The result of running the quantization transformation on a tensor.
It has the tensor's type and quantization, as they are after the transformation.
Only tensors that are quantized by the transformation have a quantization.
:param type: The tensor's type after transformation. It has the same shape
as before the transformation. Its scalar type may be different.
:param quant: The tensor's quantization, if it was quantized by the
quantization transformation. None otherwise. Floating-point tensors do
not have a quantization. Integer tensors do not have a quantization if
they were already integer before the quantization transformation.
:param requant_method: The method that should be used for requantizing this
tensor's value when requantization is needed. This field must
be None iff quant is None.
"""
type: TensorType
quant: Optional[Quantization]
requant_method: Optional[RequantMethod]
def __post_init__(self):
assert (self.quant is not None) == (self.requant_method is not None)
@staticmethod
def from_type(type: TensorType) -> "QuantResultTensorType":
"""
Make a QuantResultTensorType that only has type information.
"""
return QuantResultTensorType(type, None, None)
@staticmethod
def from_quant(quant: Optional[Quantization]) -> "QuantResultTensorType":
"""
Make a QuantResultTensorType from a Quantization using dummy type information.
This is a temporary method that should be removed when support for
QuantResultTensorType is finished.
"""
return QuantResultTensorType(dummy_quant_result_tensor_type.type, quant, RequantMethod.fractional_zero)
def get_quant_result_scale_with_dummy(t: QuantResultTensorType) -> Quantization:
"""
Get the quantization scale; if there is none, return a dummy value.
The dummy value is a temporary solution that should be removed when support
for QuantResultTensorType is finished.
"""
if t.quant is None:
return dummy_quant_result_tensor_type.quant
return t.quant
def get_data_value_quant_result_scale_with_dummy(t: DataValue[QuantResultTensorType]) \
-> DataValue[Quantization]:
"""Run get_quant_result_scale_with_dummy on the contents of a DataValue."""
return map_data_value(get_quant_result_scale_with_dummy, t)
def get_dict_quant_result_scale_with_dummy(t: Dict[InputName, DataValue[QuantResultTensorType]]) \
-> Dict[InputName, DataValue[Quantization]]:
"""Run get_quant_result_scale_with_dummy on the contents of a dict of DataValue."""
return {k: get_data_value_quant_result_scale_with_dummy(v) for k, v in t.items()}
def update_quant_result_quantization(t: DataValue[QuantResultTensorType], new_quant: DataValue[Quantization]) \
-> DataValue[QuantResultTensorType]:
"""
Insert the given quantization into t, replacing existing quantization values in t.
:param t: Quantization result type to modify
:param new_quant: Quantization values
:return: A copy of t with values from new_quant inserted
"""
if isinstance(t, TensorValue):
assert isinstance(new_quant, TensorValue)
return TensorValue(replace(t.value, quant=new_quant.value))
else:
assert isinstance(t, TupleValue)
assert isinstance(new_quant, TupleValue)
assert len(t.elements) == len(new_quant.elements)
new_elements = [update_quant_result_quantization(t_elem, new_quant_elem)
for t_elem, new_quant_elem in zip(t.elements, new_quant.elements)]
return TupleValue(new_elements)
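# Illustrative sketch (hypothetical values, not from this module): the recursion above mirrors
# the DataValue structure. For a two-output node,
#     t         = TupleValue([TensorValue(qrt_a), TensorValue(qrt_b)])
#     new_quant = TupleValue([TensorValue(q_a), TensorValue(q_b)])
# yields TupleValue([TensorValue(replace(qrt_a, quant=q_a)), TensorValue(replace(qrt_b, quant=q_b))]),
# where qrt_* are QuantResultTensorType values and q_* are Quantization values.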
def update_quant_result_type(t: DataValue[QuantResultTensorType], new_type: DataValue[TensorType]) \
-> DataValue[QuantResultTensorType]:
"""
Ensure that t's type matches new_type by replacing dummy types with data from t and
checking non-dummy types.
This function's purpose is to save the type into t while developing QuantResultTensorType,
then to check consistency after it is developed.
:param t: Quantization result type, which may contain dummy types
:param new_type: Type that should be the same as the type in t
:return: A copy of t with values from new_type inserted to replace any dummy types
"""
if isinstance(t, TensorValue):
assert isinstance(new_type, TensorValue)
if is_dummy_type(t.value.type):
return TensorValue(replace(t.value, type=new_type.value))
else:
assert t.value.type == new_type.value, f"Expected {new_type.value}, got {t.value.type}"
return t
else:
assert isinstance(t, TupleValue)
assert isinstance(new_type, TupleValue)
assert len(t.elements) == len(new_type.elements)
new_elements = [update_quant_result_type(t_elem, new_type_elem)
for t_elem, new_type_elem in zip(t.elements, new_type.elements)]
return TupleValue(new_elements)
def set_quant_result_type_batch_size(t: DataValue[QuantResultTensorType], batch_size: int) \
-> DataValue[QuantResultTensorType]:
"""
Modifies DataValue of QuantResultTensorType with given batch size.
:param t: DataValue[QuantResultTensorType]. Value to be modified.
:param batch_size: int. Batch size value to be used in constructing new QuantResultTensorType DataValue.
:return: DataValue[QuantResultTensorType]. QuantResultTensorType with its type's shape field modified
to use batch_size.
"""
return map_data_value(lambda x: replace(x, type=set_tensor_type_batch_size(x.type, batch_size)), t)
# Dummy value used to initialize data structures. It should be replaced by the
# real value before the data gets used.
dummy_quant_result_tensor_type = \
QuantResultTensorType(TensorType(ScalarType.int8, (1, 17, 17, 16)), Quantization(), RequantMethod.fractional_zero)
# Check if the given type is the dummy type. This is used temporarily while
# implementing quantization type support, to detect when the implementation is not
# assigning a type.
def is_dummy_type(t: TensorType) -> bool:
return t == dummy_quant_result_tensor_type.type
########################
# Calibration attributes
########################
@dataclass
class AwesomeCalibAttrs:
"""
Calibration attributes
:param observer: Observer used during calibration of the node. If the node does not use
calibration data for calculation of quantization parameters, observer will not be
created and its value will be None.
:param intermediate_observers: Observers used for quantization of intermediate results.
:param quant: Quantization scale of the output. It is assigned during quantization.
:param input_quant: Quantization scale of each input. During quantization, it is first
assigned the type and quantization scale that were determined at the nodes that
compute the inputs. Then, when the node is quantized, it is assigned the types and
quantization scales of inputs that the node accepts.
"""
observer: Optional[NodeObserver] = None
quant: DataValue[QuantResultTensorType] = TensorValue(dummy_quant_result_tensor_type)
input_quant: Dict[InputName, DataValue[QuantResultTensorType]] = field(default_factory=dict)
precomputed_quant: Optional[Quantization] = None
def __post_init__(self):
def _is_quantization_data_value(dv: DataValue) -> bool:
return reduce_data_value(operator.and_,
map_data_value(lambda x: isinstance(x, QuantResultTensorType), dv),
True)
assert isinstance(self.quant, DataValue)
assert _is_quantization_data_value(self.quant)
assert isinstance(self.input_quant, Dict)
for k, v in self.input_quant.items():
assert isinstance(k, str)
assert isinstance(v, DataValue)
assert _is_quantization_data_value(v)
def set_batch_size(self, batch_size: int):
"""
Modify internal parameters' shapes for the given batch size.
"""
for k, v in self.input_quant.items():
self.input_quant[k] = set_quant_result_type_batch_size(v, batch_size)
self.quant = set_quant_result_type_batch_size(self.quant, batch_size)
#########################
# Quantization attributes
#########################
class AwesomeQuantAttrBase:
"""
Base class of quantized operator attributes. This class is used for instance
checking only.
"""
def set_batch_size(self, batch_size: int):
"""
Modify internal parameters' shapes for the given batch size.
Should be implemented separately inside each class that inherits from AwesomeQuantAttrBase.
"""
pass
@dataclass
class PlaceholderQuantAttrs(AwesomeQuantAttrBase):
"""
Properties of a quantized placeholder.
:param type: Type of the placeholder's output.
:param quantization: Quantization of the placeholder, if it was quantized
by the Quantize compiler pass.
"""
type: TensorType
quantization: Optional[Quantization]
def set_batch_size(self, batch_size: int):
"""
Modify internal parameters' shapes for the given batch size.
"""
self.type = set_tensor_type_batch_size(self.type, batch_size)
@dataclass
class ConstantQuantAttrs(AwesomeQuantAttrBase):
"""
:param quant_data: Quantized tensor value
"""
@dataclass
class ConcatQuantAttrs(AwesomeQuantAttrBase):
"""
Contains quantization attributes for concatenate quantization.
:param input_scale_corrections: Quantized scale correction for each input.
:param input_zp_corrections: Quantized zero point correction for each input.
:param right_shifts: Number of bits to right-shift during requantization at inference time.
:param layer_bits: Number of bits used for quantizing the tensor.
:param axis: The axis along which the tensors are concatenated.
:param node_zps: Zero point(s) of the quantized output tensor(s).
:param input_scales: Quantized scale for each input.
:param node_scales: Scale(s) of the quantized output tensor(s); the maximum of input_scales is used as the concatenate output scale.
"""
right_shifts: List[int] = field(default_factory=list)
layer_bits: List[int] = field(default_factory=lambda: [8])
axis: Optional[int] = None
input_scales: Optional[List[Union[float, List[float]]]] = field(default_factory=list) # For Backend Model_Builder
node_scales: Optional[List[float]] = field(default_factory=list) # For Backend Model_Builder
node_zps: Optional[List[int]] = None # For graph_analyzer.
rounding_type: RoundType = RoundType.TOEVEN
@dataclass
class MultiplyQuantAttrs(AwesomeQuantAttrBase):
"""
:param lhs_input_shape: Lhs input shape.
:param rhs_input_shape: Rhs input shape.
:param input_int16: If True, the inputs have int16 type. If False, the inputs have int8 type.
:param intrinsic_shift: Right-shift to apply before requantization.
:param requant: Requantization parameters.
:param lhs_zero_point: Zero point of the left-hand side input.
:param rhs_zero_point: Zero point of the right-hand side input.
:param layer_bits: Number of bits used to quantize the output tensor.
"""
requant: BaseRequantization[np.ndarray]
lhs_zero_point: int = 0
rhs_zero_point: int = 0
def set_batch_size(self, batch_size: int):
"""
Modify internal parameters' shapes for the given batch size.
"""
self.lhs_input_shape = set_shape_batch_size(self.lhs_input_shape, batch_size)
self.rhs_input_shape = set_shape_batch_size(self.rhs_input_shape, batch_size)
@dataclass
class LeakyReluQuantAttrs(AwesomeQuantAttrBase):
"""
The slope for quantized_input < zero_point is (alpha >> right_shift).
"""
rounding_type: RoundType
def set_batch_size(self, batch_size: int):
"""
Modify internal parameters' shapes for the given batch size.
"""
self.input_shape = set_shape_batch_size(self.input_shape, batch_size)
##################
# Basic attributes
##################
@dataclass
class AwesomeAttributes:
"""
A class that stores attributes necessary for the execution of its associated AwesomeOperation.
Subclasses should include all additional attributes in their __init__ functions and call back to the
AwesomeAttributes __init__ function to include the default attributes
"""
def set_batch_size(self, batch_size: int):
"""
Modify internal parameters' shapes for the given batch size.
Should be implemented separately inside each class that inherits from AwesomeAttributes.
"""
pass
@dataclass
class PlaceholderAttrs(AwesomeAttributes):
"""
Properties of a placeholder.
"""
def set_batch_size(self, batch_size: int):
"""
Modify internal parameters' shapes for the given batch size.
"""
self.type = set_tensor_type_batch_size(self.type, batch_size)
@dataclass
class ConstantAttrs(AwesomeAttributes):
"""
:param data: Tensor value before quantization
"""
# Tensor value before quantization pass. May be integer or floating-point.
data: np.ndarray
@dataclass
class MultiplyAttrs(AwesomeAttributes):
"""
Attributes of a multiply operator.
:param scalar_type: Type of input and output. Must be a floating-point type.
:param lhs_input_shape: Shape of first input.
:param rhs_input_shape: Shape of second input.
"""
scalar_type: ScalarType
def __post_init__(self):
assert scalar_is_floating(self.scalar_type)
def set_batch_size(self, batch_size: int):
"""
Modify internal parameters' shapes for the given batch size.
"""
self.lhs_input_shape = set_shape_batch_size(self.lhs_input_shape, batch_size)
self.rhs_input_shape = set_shape_batch_size(self.rhs_input_shape, batch_size)
def convolution_output_shape(conv_attrs: "ConvAttrs") -> tuple[int, ...]:
"""
Get the shape of a convolution's output tensor based on its attributes.
"""
spatial_shape = []
for n, k, (p_lo, p_hi), (q_lo, q_hi), s, d in zip(conv_attrs.input_spatial_shape, conv_attrs.kernel_size,
conv_attrs.padding, conv_attrs.output_padding,
conv_attrs.stride, conv_attrs.dilation):
if not conv_attrs.is_transposed:
o = ((n + p_lo + p_hi - d * (k - 1) - 1) // s) + 1
else:
o = (n - 1)*s - p_lo - p_hi + d * (k - 1) + q_hi + 1
spatial_shape.append(o)
return (conv_attrs.batch_size, *spatial_shape, conv_attrs.channels)
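# Worked example (illustrative values): for a non-transposed convolution with input spatial size
# n = 32, kernel k = 3, padding (p_lo, p_hi) = (1, 1), stride s = 2, and dilation d = 1:
#     o = ((32 + 1 + 1 - 1 * (3 - 1) - 1) // 2) + 1 = (31 // 2) + 1 = 16
# For a transposed convolution with n = 16, k = 3, s = 2, padding (1, 1), d = 1, and
# output_padding (0, 1):
#     o = (16 - 1) * 2 - 1 - 1 + 1 * (3 - 1) + 1 + 1 = 32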
@dataclass
class ConvAttrs(AwesomeAttributes):
"""
Attributes of a convolution operation.
The attributes describe a convolution with input and output activations
in NWC, NHWC, or NDHWC layout and weights in WIGO, HWIGO, or DHWIGO layout.
The dimension order for 1, 2, or 3 spatial dimensions respectively is W, HW, or DHW.
Args:
stride: Stride in each spatial dimension
dilation: Dilation in each spatial dimension
padding: Padding in each spatial dimension. The padding in each dimension is
a tuple holding the padding width at the beginning and end of the dimension.
output_padding: Padding of the output tensor in each spatial dimension for transposed
convolution. If it is not a transposed convolution, all padding values must be zero.
If it is a transposed convolution, the first element of the padding must be zero.
is_transposed: Whether it is a transposed convolution.
weight_shape: Shape of the weight tensor in spatial dimensions ++ IGO layout, for example HWIGO.
IGO is an abbreviation for "input channels, groups, output channels".
input_spatial_shape: Shape of the input tensor in spatial dimensions.
batch_size: Batch size.
input_type: Scalar type of the convolution's input tensor. This type is ignored
for quantized convolutions.
"""
stride: tuple[int, ...]
dilation: tuple[int, ...]
padding: tuple[tuple[int, int], ...]
output_padding: tuple[tuple[int, int], ...]
is_transposed: bool
weight_shape: tuple[int, ...]
input_spatial_shape: tuple[int, ...]
batch_size: int
input_type: ScalarType
def __post_init__(self):
spatial_dims = len(self.input_spatial_shape)
assert 1 <= spatial_dims <= 3, "Convolution must have 1, 2, or 3 spatial dimensions"
assert len(self.weight_shape) == spatial_dims + 3, \
"Dimensionality of input_spatial_shape and weight_shape is not consistent"
assert len(self.stride) == spatial_dims, f"Expected stride to have {spatial_dims} elements"
assert len(self.dilation) == spatial_dims, f"Expected dilation to have {spatial_dims} elements"
assert len(self.padding) == spatial_dims, f"Expected padding to have {spatial_dims} elements"
assert len(self.output_padding) == spatial_dims, f"Expected output_padding to have {spatial_dims} elements"
if self.is_transposed:
assert all(p[0] == 0 for p in self.output_padding), \
"First element of output padding must be zero for transposed convolution"
else:
assert all(p == (0, 0) for p in self.output_padding), \
"Output padding must be zero for non-transposed convolution"
@property
def groups(self) -> int:
"""
Get the number of convolution groups.
"""
return self.weight_shape[-2]
@property
def channels(self) -> int:
"""
Get the number of convolution output channels.
"""
# output_channels_per_group * groups
return self.weight_shape[-2] * self.weight_shape[-1]
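# Illustrative example (assumed weight shapes, not taken from this module): in HWIGO layout,
# a depthwise 3x3 convolution over 32 channels could have weight_shape == (3, 3, 1, 32, 1),
# giving groups == 32 and channels == 32 * 1 == 32; a regular 3x3 convolution mapping 16 input
# channels to 64 output channels could have weight_shape == (3, 3, 16, 1, 64), giving
# groups == 1 and channels == 64.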
@property
def kernel_size(self) -> tuple[int, ...]:
"""
Get the shape of the convolution kernel in the spatial dimensions.
"""
return self.weight_shape[:-3]
@property
def num_spatial_dimensions(self) -> int:
"""
Get the number of spatial dimensions for this convolution.
"""
return len(self.input_spatial_shape)
@property
def output_shape(self) -> tuple[int, ...]:
"""
Get the shape of the convolution's output tensor in NWC, NHWC, or NDHWC layout.
"""
return convolution_output_shape(self)
@property
def is_depthwise_one_channel(self) -> bool:
"""
Return true if this convolution is a depthwise convolution with
equal number of input and output channels.
"""
# True if weight shape has many groups, one input channel, one output channel
return self.weight_shape[-2] > 1 and self.weight_shape[-3] == 1 and self.weight_shape[-1] == 1
@dataclass
class PoolAttrs(AwesomeAttributes):
"""
:param ceil_mode: Used to take ceil or floor when computing the output shape
:param out_layout: Layout of the output. This can be an empty str if layout is the same as data_layout.
:param layout: Uses the letters NHWC for batch, height, width, and channels.
:param padding: ((pad_top, pad_bot), ...) along the dimensions of NHWC according to layout
:param pool_size: Size of pooling
:param strides: Strides
:param dilation: Dilation along the dimensions of NHWC according to data_layout
:param scalar_type: Data type of the input and output.
"""
pool_size: AwesomePoolSize
strides: AwesomeStrides
dilation: AwesomeDilation
scalar_type: ScalarType
def __post_init__(self):
# TODO: move the constraints to the checker functions
if np.prod(self.dilation) > 1:
raise RuntimeError(f"We currently do not support dilation > 1 for Pool. Got dilation = {self.dilation}")
if self.ceil_mode:
raise RuntimeError("We currently do not support ceil_mode == True for Pool. "
"Make sure the SetCeilModeToFalseForNDPooling pass is enabled when importing the model")
assert scalar_is_floating(self.scalar_type)
def set_batch_size(self, batch_size: int):
"""
Modify internal parameters' shapes for the given batch size.
"""
self.input_shape = set_shape_batch_size(self.input_shape, batch_size)
@dataclass
class MaxPoolAttrs(PoolAttrs):
pass
@dataclass
class AvgPoolAttrs(PoolAttrs):
"""
:param count_include_pad: If true, include padding to compute the average.
"""
count_include_pad: bool
@dataclass
class VarianceAttrs(AwesomeAttributes):
"""
Attributes:
input_data_shape: Shape of the input tensor.
mean_shape: Shape of the mean input tensor.
scalar_type: Scalar type of the input and output.
axis: The axes to sum over when computing mean.
"""
mean_shape: tuple[int, ...]
scalar_type: ScalarType
@dataclass
class AdaptiveAvgPool2DAttrs(AwesomeAttributes):
"""
:param output_size: Tuple of int, optional. Output height and width.
:param out_layout: Layout of the output. This can be an empty str if layout is the same as data_layout.
:param layout: Layout of the input.
"""
output_size: Tuple[int, ...]
def __post_init__(self):
_error_msg1 = "Doesn't support AdaptiveAvgPool2DOp when output_size is None"
_error_msg2 = f"Only support AdaptiveAvgPool2DOp with output_size is 1x1, got {self.output_size}"
assert self.output_size is not None, _error_msg1
assert np.prod(self.output_size) == 1, _error_msg2
@dataclass
class ReluAttrs(AwesomeAttributes):
"""
:param scalar_type: Type of input and output.
:param input_shape: Shape of input.
"""
scalar_type: ScalarType
def set_batch_size(self, batch_size: int):
"""
Modify internal parameters' shapes for the given batch size.
"""
self.input_shape = set_shape_batch_size(self.input_shape, batch_size)
@dataclass
class AddAttrs(AwesomeAttributes):
"""
Attributes of an add operator.
:param scalar_type: Type of input and output. Must be a floating-point type.
:param lhs_input_shape: Shape of first input.
:param rhs_input_shape: Shape of second input.
"""
scalar_type: ScalarType
def __post_init__(self):
assert scalar_is_floating(self.scalar_type)
def set_batch_size(self, batch_size: int):
"""
Modify internal parameters' shapes for the given batch size.
"""
self.lhs_input_shape = set_shape_batch_size(self.lhs_input_shape, batch_size)
self.rhs_input_shape = set_shape_batch_size(self.rhs_input_shape, batch_size)
@dataclass
class SubtractAttrs(AwesomeAttributes):
"""
Attributes of a subtract operator.
:param scalar_type: Type of input and output. Must be a floating-point type.
:param lhs_input_shape: Shape of first input.
:param rhs_input_shape: Shape of second input.
"""
scalar_type: ScalarType
def __post_init__(self):
assert scalar_is_floating(self.scalar_type)
def set_batch_size(self, batch_size: int):
"""
Modify internal parameters' shapes for the given batch size.
"""
self.lhs_input_shape = set_shape_batch_size(self.lhs_input_shape, batch_size)
self.rhs_input_shape = set_shape_batch_size(self.rhs_input_shape, batch_size)
@dataclass
class BiasAddAttrs(AwesomeAttributes):
"""
:param input_shape: The shape of the input activation tensor
:param axis: The axis to add the bias
"""
def set_batch_size(self, batch_size: int):
"""
Modify internal parameters' shapes for the given batch size.
"""
self.input_shape = set_shape_batch_size(self.input_shape, batch_size)
@dataclass
class ConstantMultiplyAddAttrs(AwesomeAttributes):
"""
Attributes representing the computation (a*c + b*d) for scalar constants c and d.
"""
scalar_type: ScalarType
in1_const_attrs: ConstantAttrs
# Scalar value to multiply by the second operand. If None, it represents multiplication by 1.
in2_const_attrs: Optional[ConstantAttrs]
def __init__(self, scalar_type: ScalarType,
lhs_input_shape: Tuple[int, ...],
rhs_input_shape: Tuple[int, ...],
in1_const_attrs: ConstantAttrs,
in2_const_attrs: Optional[ConstantAttrs] = None):
# Only scalar constants are supported at this point
# TODO: Add support for vector constants
self.scalar_type = scalar_type
self.lhs_input_shape = lhs_input_shape
self.rhs_input_shape = rhs_input_shape
assert len(in1_const_attrs.data.shape) == 1
self.in1_const_attrs = in1_const_attrs
if in2_const_attrs is not None:
assert len(in2_const_attrs.data.shape) == 1
self.in2_const_attrs = in2_const_attrs
def set_batch_size(self, batch_size: int):
"""
Modify internal parameters' shapes for the given batch size.
"""
self.lhs_input_shape = set_shape_batch_size(self.lhs_input_shape, batch_size)
self.rhs_input_shape = set_shape_batch_size(self.rhs_input_shape, batch_size)
@dataclass
class MeanAttrs(AwesomeAttributes):
"""
:param axis: Axis or axes along which a mean operation is performed.
:param exclude: If `exclude` is true, we use the axes that are NOT in the axis field
:param keepdims: If set to true, the reduced axes are left with a size of 1
"""
def __post_init__(self):
assert isinstance(self.axis, List), \
f"Type mismatch for axis field. Expected List[int], got {type(self.axis)}."
for a in self.axis:
assert isinstance(a, int), \
f"Type mismatch for axis field. Expected List[int], got List[{type(a)}]."
assert isinstance(self.exclude, bool), \
f"Type mismatch for exclude field. Expected bool, got {type(self.exclude)}."
assert isinstance(self.keepdims, bool), \
f"Type mismatch for keepdims field. Expected bool, got {type(self.keepdims)}."
assert isinstance(self.shape, Tuple), \
f"Type mismatch for shape field. Expected Tuple[int, ...], got {type(self.shape)}."
for s in self.shape:
assert isinstance(s, int), \
f"Type mismatch for shape field. Expected Tuple[int, ...], got Tuple[{type(self.shape)}]."
def set_batch_size(self, batch_size: int):
"""
Modify internal parameters' shapes for the given batch size.
"""
self.shape = set_shape_batch_size(self.shape, batch_size)
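# Illustrative note (interpretation of the axis/exclude fields documented above): for a 4-D input
# with axis=[1, 2] and exclude=False, the mean is taken over axes 1 and 2; with exclude=True,
# it is instead taken over the remaining axes 0 and 3.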
@dataclass
class MeanQuantAttrs(AwesomeQuantAttrBase):
"""
Contains quantization attributes for mean quantization.
:param attrs: MeanAttrs used in mean operator.
:param node_scales: Scale(s) of the quantized output tensor(s).
:param node_zps: Zero point(s) of the quantized output tensor(s).
"""
node_scales: float = 1.0
def set_batch_size(self, batch_size: int):
"""
Modify internal parameters' shapes for the given batch size.
"""
self.attrs.set_batch_size(batch_size)
@dataclass
class ArgMaxAttrs(AwesomeAttributes):
"""
:param axis: Axis or axes along which the argmax operation is performed.
:param exclude: If `exclude` is true, we use the axes that are NOT in the axis field
:param keepdims: If set to true, the reduced axes are left with a size of 1
:param select_last_index: Whether to select the last index or the first index if the max element
appears in multiple indices.
:param shape: Shape of input tensor
:param result_scalar_type: Type of numbers in result tensor. It must be either
ScalarType.int32 or the same as the input tensor's type.
:param input_scalar_type: Type of input values. It must be either ScalarType.float32 or ScalarType.int8.
"""
select_last_index: bool
result_scalar_type: ScalarType
def set_batch_size(self, batch_size: int):
"""
Modify internal parameters' shapes for the given batch size.
"""
self.shape = set_shape_batch_size(self.shape, batch_size)
@dataclass
class LayerNormAttrs(AwesomeAttributes):
"""
:param axis: The axis to sum over when computing mean.
:param input_shape: Shape of input tensor.
:param epsilon: The epsilon value to use to avoid division by zero.
:param scalar_type: Type of input and output.
"""
axis: int | tuple[int, int]
scalar_type: ScalarType
def set_batch_size(self, batch_size: int):
"""
Modify internal parameters' shapes for the given batch size.
"""
self.input_shape = set_shape_batch_size(self.input_shape, batch_size)
@dataclass
class RMSNormAttrs(AwesomeAttributes):
"""
:param input_shape: Shape of input tensor.
:param epsilon: The epsilon value to use to avoid division by zero.
:param scalar_type: Type of input and output.
"""
scalar_type: ScalarType
def set_batch_size(self, batch_size: int):
"""
Modify internal parameters' shapes for the given batch size.
"""
self.input_shape = set_shape_batch_size(self.input_shape, batch_size)
@dataclass
class InstanceNormAttrs(AwesomeAttributes):
"""
Instance Normalization operator attributes.
Attributes:
axis: The axes to sum over when computing mean.
input_data_shape: Shape of the input tensor.
mean_shape: Shape of the mean input tensor.
variance_shape: Shape of the variance input tensor.
epsilon: The epsilon value to use to avoid division by zero.
scalar_type: Type of input and output.
"""
mean_shape: tuple[int, ...]
variance_shape: tuple[int, ...]
scalar_type: ScalarType
def set_batch_size(self, batch_size: int):
"""
Modify internal parameters' shapes for the given batch size.
"""
self.input_data_shape = set_shape_batch_size(self.input_data_shape, batch_size)
self.mean_shape = set_shape_batch_size(self.mean_shape, batch_size)
self.variance_shape = set_shape_batch_size(self.variance_shape, batch_size)
@dataclass
class SoftmaxAttrs(AwesomeAttributes):
"""
:param axis: The axis to sum over when computing softmax
:param input_shape: Shape of input tensor
:param scalar_type: Type of input and output
"""
scalar_type: ScalarType
def set_batch_size(self, batch_size: int):
"""
Modify internal parameters' shapes for the given batch size.
"""
self.input_shape = set_shape_batch_size(self.input_shape, batch_size)
@dataclass
class PadAttrs(AwesomeAttributes):
"""
:param pad_mode: 'constant', 'edge', 'reflect'
:param pad_width: padding along each input dimension N in the format of (before_N, after_N)
"""
def __post_init__(self):
assert self.pad_mode == 'constant', f"Only support 'constant' pad_mode. Got {self.pad_mode}"
assert np.sum(self.pad_width[0]) == 0, f"Don't support batch dimension padding. Got {self.pad_width}"
def set_batch_size(self, batch_size: int):
"""
Modify internal parameters' shapes for the given batch size.
"""
self.input_shape = set_shape_batch_size(self.input_shape, batch_size)
@dataclass
class LRNAttrs(AwesomeAttributes):
"""
:param alpha: The scaling parameter.
:param axis: Input data layout channel axis. Default value is 1 for NCHW format
:param beta: The exponent parameter.
:param bias: The offset parameter to avoid dividing by 0.
:param size: The size of the local region to be considered for normalization.
:param shape: Shape of input tensor
# NOTES FOR TENSORFLOW
# TVM defines size as size_tvm = (depth_radius_tf * 2) + 1
# TVM defines alpha as alpha_tvm = alpha_tf * size_tf
"""
def set_batch_size(self, batch_size: int):
"""
Modify internal parameters' shapes for the given batch size.
"""
self.shape = set_shape_batch_size(self.shape, batch_size)
@dataclass
class ClipAttrs(AwesomeAttributes):
"""
Attributes for Clip operation. Clip operation is always merged into a composite operator.
Same class is used in floating-point and quantized version.
:param a_min: min value of clip
:param a_max: max value of clip
:param shape: Shape of input tensor
"""
scalar_type: ScalarType
def set_batch_size(self, batch_size: int):
"""
Modify internal parameters' shapes for the given batch size.
"""
self.shape = set_shape_batch_size(self.shape, batch_size)
@dataclass
class ExtmAttrs(AwesomeAttributes):
"""
Attributes for extremum op, can be min or max op depending on the max boolean.
:param axis: Axis or axes along which the extremum operation is performed.
:param exclude: If `exclude` is true, we use the axes that are NOT in the axis field
:param keepdims: If set to true, the reduced axes are left with a size of 1
:param max: If true the operation is max, if false the operation is min.
"""
def set_batch_size(self, batch_size: int):
"""
Modify internal parameters' shapes for the given batch size.
"""
self.shape = set_shape_batch_size(self.shape, batch_size)
@dataclass
class SumAttrs(AwesomeAttributes):
"""
:param axis: Axis or axes along which the sum operation is performed.
:param exclude: If `exclude` is true, we use the axes that are NOT in the axis field
:param keepdims: If set to true, the reduced axes are left with a size of 1
:param num_element: Number of elements to be summed. This attribute is not a default TVM attribute.
It will be assigned during any floating point inference and used in quantization.
"""
# Not a TVM attribute
def set_batch_size(self, batch_size: int):
"""
Modify internal parameters' shapes for the given batch size.
"""
self.shape = set_shape_batch_size(self.shape, batch_size)
@dataclass
class ProdAttrs(AwesomeAttributes):
"""
:param axis: Axis or axes along which the product operation is performed.
:param exclude: If `exclude` is true, we use the axes that are NOT in the axis field
:param keepdims: If set to true, the reduced axes are left with a size of 1
"""
def set_batch_size(self, batch_size: int):
"""
Modify internal parameters' shapes for the given batch size.
"""
self.shape = set_shape_batch_size(self.shape, batch_size)
@dataclass
class FullAttrs(AwesomeAttributes):
"""
:param shape: The shape of the target.
:param dtype: The data type of the target.
"""
@dataclass
class TileAttrs(AwesomeAttributes):
"""
:param reps: The number of times to repeat the tensor data.
"""
@dataclass
class UpsamplingAttrs(AwesomeAttributes):
"""
:param input_shape: Shape of the input tensor.
:param scale_h: The scale factor for height upsampling.
:param scale_w: The scale factor for width upsampling.
:param layout: Layout of the input.
:param method: Scale method to use [nearest_neighbor, bilinear, bicubic].
:param align_corners: Whether to keep corners in proper place.
:param scalar_type: Data type.
"""
scalar_type: ScalarType
def set_batch_size(self, batch_size: int):
"""
Modify internal parameters' shapes for the given batch size.
"""
self.input_shape = set_shape_batch_size(self.input_shape, batch_size)
@dataclass
class ImageResize2DAttrs(AwesomeAttributes):
"""
:param size: The out size to which the image will be resized.
:param roi: The region of interest for cropping the input image. Expected to be of size 4 and format
[start_h, start_w, end_h, end_w]. Only used if coordinate_transformation_mode is
'tf_crop_and_resize'.
:param layout: Layout of the input.
:param method: Scale method to use [nearest_neighbor, linear, bicubic].
:param coordinate_transformation_mode: Describes how to transform the coordinate in the resized tensor
to the coordinate in the original tensor.
Refer to the ONNX Resize operator specification for details.
[half_pixel, align_corners, asymmetric]
:param rounding_method: (string, optional) - Indicates how to find the “nearest” pixel in nearest_neighbor
method [round, floor, ceil]
:param cubic_alpha: (float) – Spline Coefficient for Bicubic Interpolation
:param cubic_exclude: (int) – Flag to exclude exterior of the image during bicubic interpolation
:param extrapolation_value: Fill value to use when roi is outside of the image.
:param out_dtype: Type to return. If left None returns the same type as input.
"""
def set_batch_size(self, batch_size: int):
"""
Modify internal parameters' shapes for the given batch size.
"""
self.input_shape = set_shape_batch_size(self.input_shape, batch_size)
@dataclass
class GridSampleAttrs(AwesomeAttributes):
"""Attributes of GridSample operator.
input_shape: Shape of the input tensor.
grid_shape: Shape of the grid tensor.
method: Interpolation method to use ["nearest", "bilinear", "bicubic"].
padding_mode: padding mode ["zeros", "border", "reflection"].
align_corners: Whether to align the corners in interpolation.
scalar_type: Data type.
"""
grid_shape: Tuple[int, ...]
scalar_type: ScalarType
def set_batch_size(self, batch_size: int):
"""
Modify internal parameters' shapes for the given batch size.
"""
self.input_shape = set_shape_batch_size(self.input_shape, batch_size)
@dataclass
class MaximumAttrs(AwesomeAttributes):
def set_batch_size(self, batch_size: int):
"""
Modify internal parameters' shapes for the given batch size.
"""
self.input_shape = set_shape_batch_size(self.input_shape, batch_size)
@dataclass
class MinimumAttrs(AwesomeAttributes):
def set_batch_size(self, batch_size: int):
"""
Modify internal parameters' shapes for the given batch size.
"""
self.input_shape = set_shape_batch_size(self.input_shape, batch_size)
################################
# TENSOR MANIPULATION ATTRIBUTES
################################
@dataclass
class TensorManipulationBaseAttrs(AwesomeAttributes):
"""
Does nothing. Used to better structure the class hierarchy.
"""
@dataclass
class TupleAttrs(AwesomeAttributes):
def __post_init__(self):
assert len(self.input_types) > 0
for it in self.input_types:
assert isinstance(it, TensorType)
def set_batch_size(self, batch_size: int):
"""
Modify internal parameters' shapes for the given batch size.
"""
self.input_types = [
set_tensor_type_batch_size(input_type, batch_size) for input_type in self.input_types
]
@dataclass
class TupleGetItemAttrs(AwesomeAttributes):
"""
:param input_types: List of input tensor types
:param index: The index of the tuple_value we return
"""
def __post_init__(self):
assert len(self.input_types) > 0
for it in self.input_types:
assert isinstance(it, TensorType)
def set_batch_size(self, batch_size: int):
"""
Modify internal parameters' shapes for the given batch size.
"""
self.input_types = [
set_tensor_type_batch_size(input_type, batch_size) for input_type in self.input_types
]
@dataclass
class SqueezeAttrs(TensorManipulationBaseAttrs):
"""
:param axis: Set of axes to remove
:param input_shape: Shape of input tensor
:param input_type: Data type of input tensor
"""
def __post_init__(self):
_error_msg = f"Does not support squeeze for batch dimension. Got axis={self.axis}"
assert self.axis is not None and 0 not in self.axis, _error_msg
def set_batch_size(self, batch_size: int):
"""
Modify internal parameters' shapes for the given batch size.
"""
self.input_shape = set_shape_batch_size(self.input_shape, batch_size)
@dataclass
class ConcatenateAttrs(TensorManipulationBaseAttrs):
"""
:param scalar_type: Scalar type of the output.
:param axis: The axis along which the tensors are concatenated.
:param input_types: List of input tensor types.
"""
scalar_type: ScalarType
def __post_init__(self):
assert isinstance(self.axis, int)
if self.axis == 0:
raise NotImplementedError("Concatenate does not support axis along batch(0)")
def set_batch_size(self, batch_size: int):
"""
Modify internal parameters' shapes for the given batch size.
"""
self.input_types = [
set_tensor_type_batch_size(input_type, batch_size) for input_type in self.input_types
]
@dataclass
class TransposeAttrs(TensorManipulationBaseAttrs):
"""
:param axes: The target axes order, reverse order if not specified.
:param input_shape: Shape of input tensor
:param input_type: Data type of input tensor
"""
def set_batch_size(self, batch_size: int):
"""
Modify internal parameters' shapes for the given batch size.
"""
self.input_shape = set_shape_batch_size(self.input_shape, batch_size)
@dataclass
class DepthToSpaceAttrs(TensorManipulationBaseAttrs):
"""Attributes of DepthToSpace operator
block_size: Block size that is shifted from channels to height and width
mode: DCR for depth-column-row order re-arrangement, CRD for column-row-depth order
input_shape: Shape of input tensor
input_type: Data type of input tensor
"""
def set_batch_size(self, batch_size: int):
"""
Modify internal parameters' shapes for the given batch size.
"""
self.input_shape = set_shape_batch_size(self.input_shape, batch_size)
@dataclass
class ReshapeAttrs(TensorManipulationBaseAttrs):
"""
:param input_shape: Shape of input tensor
:param dtype: Data type
:param newshape: The new shape.
"""
def set_batch_size(self, batch_size: int):
"""
Modify internal parameters' shapes for the given batch size.
"""
self.input_shape = set_shape_batch_size(self.input_shape, batch_size)
assert len(self.newshape) > 1, "ReshapeAttrs error: newshape does not have batch dimension"
self.newshape[0] = batch_size
@dataclass
class ExpandDimsAttrs(TensorManipulationBaseAttrs):
"""
:param axis: The axis that is expanded
:param num_newaxis: The number of axes to be inserted. Should be >= 0
:param input_shape: Shape of input tensor
:param input_type: Data type of input tensor
"""
def set_batch_size(self, batch_size: int):
"""
Modify internal parameters' shapes for the given batch size.
"""
self.input_shape = set_shape_batch_size(self.input_shape, batch_size)
@dataclass
class BatchFlattenAttrs(TensorManipulationBaseAttrs):
pass
@dataclass
class SplitAttrs(TensorManipulationBaseAttrs):
"""
:param indices_or_sections: Indices or sections to split into. Accepts an int or a tuple
If indices_or_sections is an integer, the input will be divided equally along given axis.
If such a split is not possible, an error is raised.
If indices_or_sections is a tuple of sorted integers, the entries indicate where along axis the array is split.
:param axis: The axis over which to split.
:param input_shape: Shape of input tensor
:param input_type: Data type of input tensor
"""
indices_or_sections: Union[int, Tuple[int, ...]]
def set_batch_size(self, batch_size: int):
"""
Modify internal parameters' shapes for the given batch size.
"""
self.input_shape = set_shape_batch_size(self.input_shape, batch_size)
@dataclass
class TakeAttrs(TensorManipulationBaseAttrs):
"""
:param axis: The axis over which to select values. By default, the flattened input array is used.
:param mode: Specifies how out-of-bound indices will behave [clip, wrap, fast].
clip: clip to the range (default).
wrap: wrap around the indices.
fast: no clip or wrap around (user must make sure indices are in-bound).
"""
indices_shape: Tuple[int, ...]
def set_batch_size(self, batch_size: int):
"""
Modify internal parameters' shapes for the given batch size.
"""
self.input_shape = set_shape_batch_size(self.input_shape, batch_size)
@dataclass
class StridedSliceAttrs(TensorManipulationBaseAttrs):
"""
:param begin: The indices to begin with in the slicing.
:param end: Indices indicating end of the slice.
:param strides: Specifies the stride values. A stride can be negative, in which case the input tensor will be
reversed along that particular axis.
:param axes: Tuple[int] or List[int], optional. Axes along which slicing is applied. When it is specified, the
length of begin, end, strides, and axes must be equal. Moreover, begin, end, strides, and axes must be
static (cannot be relay.Expr). Axes argument for dynamic parameter slicing is not supported yet.
:param slice_mode: The slice mode [end, size].
end: The ending indices for the slice [default].
size: The input strides will be ignored, input end in this mode indicates
the size of a slice starting at the location specified by begin. If end[i]
is -1, all remaining elements in that dimension are included in the slice.
:param input_shape: Shape of input tensor
:param input_type: Data type of input tensor
"""
axes: Optional[Union[Tuple[int], List[int]]]
def __post_init__(self):
if self.axes is None:
assert len(self.begin) == len(self.end) == len(self.input_shape)
else:
assert len(self.begin) == len(self.end) == len(self.axes)
if self.strides == [1] or self.strides is None:
self.strides = len(self.input_shape) * [1]
def set_batch_size(self, batch_size: int):
"""
Modify internal parameters' shapes for the given batch size.
"""
self.input_shape = set_shape_batch_size(self.input_shape, batch_size)
@dataclass
class BroadcastToAttrs(TensorManipulationBaseAttrs):
output_shape: Tuple[int, ...]
def set_batch_size(self, batch_size: int):
"""
Modify internal parameters' shapes for the given batch size.
"""
self.input_type = set_tensor_type_batch_size(self.input_type, batch_size)
self.output_shape = set_shape_batch_size(self.output_shape, batch_size)
def __post_init__(self):
input_shape = self.input_type.shape
for dim in range(len(input_shape)):
if input_shape[dim] != self.output_shape[dim] and input_shape[dim] != 1:
raise RuntimeError(f"Invalid BroadcastTo operation. Broadcasting is supported only with input"
f"dimension size 1. Input's dimension {dim} has size of {input_shape[dim]}")
@dataclass
class CastAttrs(TensorManipulationBaseAttrs):
"""
:param out_dtype: The data type of the target.
:param input_shape: Shape of input tensor.
"""
def set_batch_size(self, batch_size: int):
"""
Modify internal parameters' shapes for the given batch size.
"""
self.input_shape = set_shape_batch_size(self.input_shape, batch_size)
###############################
# UDF(Lookup table) ATTRIBUTES
###############################
@dataclass
class UDFAttrs(AwesomeAttributes):
"""
Common attributes for UDF functions:
* Sqrt
* Rsqrt
* Sigmoid
* Exp
* Tanh
* log, log2, log10
"""
scalar_type: ScalarType
def set_batch_size(self, batch_size: int):
"""
Modify internal parameters' shapes for the given batch size.
"""
self.input_shape = set_shape_batch_size(self.input_shape, batch_size)
@dataclass
class LeakyReluAttrs(UDFAttrs):
"""
:param alpha: The slope for the small gradient when x < 0
"""
@dataclass
class SwishAttrs(UDFAttrs):
pass
@dataclass
class PReluAttrs(AwesomeAttributes):
"""
:param scalar_type: Type of input and output. Must be a floating-point type.
:param axis: The axis channel dimension is specified.
:param alpha: The slope for the small gradient when x < 0 (constant tensor)
:param input_shape: Shape of input.
"""
scalar_type: ScalarType
def __post_init__(self):
assert scalar_is_floating(self.scalar_type)
def set_batch_size(self, batch_size: int):
"""
Modify internal parameters' shapes for the given batch size.
"""
self.input_shape = set_shape_batch_size(self.input_shape, batch_size)
@dataclass
class ClipQuantAttrs(AwesomeQuantAttrBase):
"""
Attributes for Clip operation. Clip operation is always merged into a composite operator.
Same class is used in floating-point and quantized version.
:param a_min: min value of clip
:param a_max: max value of clip
:param shape: Shape of input tensor
"""
scalar_type: ScalarType
def set_batch_size(self, batch_size: int):
"""
Modify internal parameters' shapes for the given batch size.
"""
self.shape = set_shape_batch_size(self.shape, batch_size)
@dataclass
class ReluQuantAttrs(AwesomeQuantAttrBase):
def set_batch_size(self, batch_size: int):
"""
Modify internal parameters' shapes for the given batch size.
"""
self.input_shape = set_shape_batch_size(self.input_shape, batch_size)
#######################
# COMPOSITE ATTRIBUTES
#######################
# *** Note composite nodes should have their attributes ordered in topological order where earlier
# attributes belong to operations that are visited first.
# Must match the order seen in afe._tvm._tvm_dataflow_pattern.py
ACTIVATION_ATTRS = TypeVar("ACTIVATION_ATTRS", ReluAttrs, ClipAttrs)
QUANT_ACTIVATION_ATTRS = TypeVar("QUANT_ACTIVATION_ATTRS", ReluQuantAttrs, ClipQuantAttrs)
###################
# Add, Activations
###################
@dataclass
class AddActivationAttrs(AwesomeAttributes, Generic[ACTIVATION_ATTRS]):
activ_attrs: Optional[ACTIVATION_ATTRS] = None
def set_batch_size(self, batch_size: int):
"""
Modify internal parameters' shapes for the given batch size.
"""
self.add_attrs.set_batch_size(batch_size)
if self.activ_attrs is not None:
self.activ_attrs.set_batch_size(batch_size)
################################
# Convolution, Add, Activations
################################
@dataclass
class ConvAddActivationAttrs(AwesomeAttributes, Generic[ACTIVATION_ATTRS]):
"""
Attributes of a fused convolution operator consisting of convolution, optional bias-add,
and optional activation function.
"""
weights_attrs: ConstantAttrs
bias_attrs: Optional[ConstantAttrs] = None
add_attrs: Optional[Union[AddAttrs, BiasAddAttrs]] = None
activ_attrs: Optional[ACTIVATION_ATTRS] = None
def set_batch_size(self, batch_size: int):
"""
Modify internal parameters' shapes for the given batch size.
"""
self.conv_attrs.set_batch_size(batch_size)
if self.add_attrs is not None:
self.add_attrs.set_batch_size(batch_size)
if self.activ_attrs is not None:
self.activ_attrs.set_batch_size(batch_size)
########
# Other
########
@dataclass
class TupleConcatenateAttrs(AwesomeAttributes):
tuple_attrs: TupleAttrs
concat_attrs: ConcatenateAttrs
def set_batch_size(self, batch_size: int):
"""
Modify internal parameters' shapes for the given batch size.
"""
self.tuple_attrs.set_batch_size(batch_size)
self.concat_attrs.set_batch_size(batch_size)
##########################
# Partitioning Attributes
##########################
@dataclass
class ExternalAttrs(AwesomeAttributes):
"""
:param external_input_list: Parameter names of the external code. This list must be equal to
list(node_type.inputs.keys()).
:param node_type: The external operation's type.
:param backend: The build target.
:param irmod_str: The TVM IRModule of the external code saved in string form. Code representations in other
fields are derived from this one. It has batch size 1, regardless of batch_size.
:param operations: A list of strings that detail the ops that are contained within the IRModule.
:param _graph_module: Lazily compiled executable representation of the external code. This module is used
for executing this node on the compilation host.
:param batch_size: The batch size that this node handles.
"""
_graph_module: Optional[TVMGraphModule] = None
def __post_init__(self):
assert self.external_input_list == list(self.node_type.inputs.keys()), "Type is not consistent with input list"
def __deepcopy__(self, memo):
"""
Override the deepcopy method because Python doesn't have a way
to deepcopy the graph_module. Sets the deep-copied ExternalAttrs'
graph_module to None and then calls __post_init__(); the graph_module
is regenerated lazily when it is next accessed.
"""
from copy import deepcopy
cls = self.__class__
result = cls.__new__(cls)
memo[id(self)] = result
for k, v in self.__dict__.items():
if k == "_graph_module":
setattr(result, k, None)
else:
setattr(result, k, deepcopy(v, memo))
result.__post_init__()
return result
def set_batch_size(self, batch_size: int):
"""
Modify internal parameters' shapes for the given batch size.
"""
self.node_type = set_node_type_batch_size(self.node_type, batch_size)
self.batch_size = batch_size
self._invalidate_graph_module()
def _invalidate_graph_module(self):
"""
Invalidate the compiled module. This should be called if the module should be recompiled because
some field was changed.
"""
self._graph_module = None
@property
def graph_module(self) -> TVMGraphModule:
"""
Get the operator's code as a TVM module that can run on the compilation host.
:return: TVM Graph module
"""
if not self._graph_module:
self._graph_module = generate_graph_module(Backend.CPU, self.irmod_str, self.node_type.inputs)
return self._graph_module
############################################
# QNN Attributes
# tvm/include/tvm/relay/qnn/attrs.h
############################################
@dataclass
class QNNQuantizeAttrs(AwesomeAttributes):
"""
Further reference: tvm/src/relay/qnn/op/quantize.cc
:param out_dtype: Specifies the output data type.
:param axis: The channel axis for quantization.
:param input_type: Tensor input type.
"""
output_scale: np.ndarray
output_zero_point: np.ndarray
def set_batch_size(self, batch_size: int):
"""
Modify internal parameters' shapes for the given batch size.
"""
self.input_type = set_tensor_type_batch_size(self.input_type, batch_size)
@dataclass
class QNNDequantizeAttrs(AwesomeAttributes):
"""
Further reference: tvm/src/relay/qnn/op/dequantize.cc
:param axis: The channel axis for quantization.
"""
def set_batch_size(self, batch_size: int):
"""
Modify internal parameters' shapes for the given batch size.
"""
self.input_type = set_tensor_type_batch_size(self.input_type, batch_size)
@dataclass
class RequantizeAttrs(AwesomeAttributes):
"""
Further reference: tvm/src/relay/qnn/op/requantize.cc
:param axis: The channel axis for quantization. This axis only applies to the input.
:param rounding: Defines the rounding direction when the value is midway
between two representable values.
:param compute_dtype: Specifies the data type used during requantize.
Supported options: "int64", "float32", "float64"
:param out_dtype: Specifies the output data type.
"""
rounding: str # TVM Requantize operator's rounding
def __post_init__(self):
assert isinstance(self.axis, int)
assert isinstance(self.rounding, str)
ScalarType.from_numpy(np.dtype(self.compute_dtype)) # Verify that compute_dtype can be interpreted as a type
ScalarType.from_numpy(np.dtype(self.out_dtype)) # Verify that out_dtype can be interpreted as a type
assert isinstance(self.input_type, TensorType)
def set_batch_size(self, batch_size: int):
"""
Modify internal parameters' shapes for the given batch size.
"""
self.input_type = set_tensor_type_batch_size(self.input_type, batch_size)
##############################
# Custom Operation Attributes
##############################
@dataclass
class CustomOpAttrs(AwesomeAttributes):
"""
Custom Op AwesomeAttributes
:param custom_op_attrs: Union[str, Dict[str, Union[str, bool]]]. Custom op attrs
in either str format or a dictionary.
:param c_code_in_dtypes: Optional[List[str]]. Input tensors' dtypes.
This attribute will be assigned at runtime.
:param c_code_in_shapes: Optional[List[Tuple[int, ...]]]. Input tensors' shapes.
This attribute will be assigned at runtime.
:param function: Optional[OperatorFunction]. Compiled custom op C function.
This attribute will be assigned at runtime.
:param args_list: Optional[Any]. A list of arguments for the custom op C function.
This attribute will be assigned at runtime.
"""
output_types: List[TensorType]
custom_op_attrs: Union[str, Dict[str, Union[str, bool]]]
c_code_in_dtypes: Optional[List[str]] = None
c_code_in_shapes: Optional[List[Tuple[int, ...]]] = None
function: Optional[OperatorFunction] = None
args_list: Optional[Any] = None
def __post_init__(self):
if isinstance(self.custom_op_attrs, str):
self.custom_op_attrs = parse_custom_op_attrs_to_dict(self.custom_op_attrs)
@dataclass
class AddQuantAttrs(AwesomeQuantAttrBase):
"""
Attributes for quantized AddActivationOp.
:param lhs_scale: Scale correction applied to the left-hand side input.
:param rhs_scale: Scale correction applied to the right-hand side input.
:param input_int16: If True, the inputs have int16 type. If False, the inputs have int8 type.
:param requant: Requantization to perform on the output.
:param relu_zero_point: Zero point of the output for relu activation. Ignored if
relu is not used.
:param layer_bits: Number of bits used to quantize the output tensor.
:param activ_attrs: Activation attributes used in Add composite operators.
"""
requant: BaseRequantization[np.ndarray]
relu_zero_point: int = 0
activ_attrs: Optional[QUANT_ACTIVATION_ATTRS] = None
@property
def node_scales(self) -> List[float]:
return [self.lhs_scale, self.rhs_scale]
@property
def node_zps(self) -> List[int]:
return [self.lhs_zero_point, self.rhs_zero_point]
def set_batch_size(self, batch_size: int):
"""
Modify internal parameters' shapes for the given batch size.
"""
self.lhs_input_shape = set_shape_batch_size(self.lhs_input_shape, batch_size)
self.rhs_input_shape = set_shape_batch_size(self.rhs_input_shape, batch_size)
if self.activ_attrs is not None:
self.activ_attrs.set_batch_size(batch_size)
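# Illustrative sketch (not part of this module): the quantized add these
# attributes describe, in simplified form. "run_requant" is a hypothetical
# helper standing in for applying a BaseRequantization.
#
#   acc = lhs.astype(np.int32) * lhs_scale + rhs.astype(np.int32) * rhs_scale  # per-input scale correction
#   out = run_requant(acc, requant)                                            # self.requant
#   if relu_is_used:
#       out = np.maximum(out, relu_zero_point)                                 # relu in the quantized domain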
@dataclass
[docs]
class SubtractQuantAttrs(AwesomeQuantAttrBase):
"""
:param attrs: SubtractAttrs class holding SubtractOp parameters.
:param input_int16: If True, the inputs have int16 type. If False, the inputs have int8 type.
:param lhs_scale: Scale correction applied to the left-hand side input.
:param rhs_scale: Scale correction applied to the right-hand side input.
:param layer_bits: Number of bits used to quantize the output tensor.
:param requant: Requantization to perform on the output.
"""
[docs]
requant: BaseRequantization[np.ndarray]
[docs]
def set_batch_size(self, batch_size: int):
"""
Modify internal parameters' shapes for the given batch size.
"""
self.attrs.set_batch_size(batch_size)
@dataclass
[docs]
class ConvQuantAttrs(AwesomeQuantAttrBase):
"""
Used for all variants of convolution.
The attributes describe the following sequence of operators (some are optional).
Relu and clip are mutually exclusive.
1. convolution
2. bias_add
3. requantize
4. relu/clip
Due to limitations of how the backend is implemented, we cannot allow
the combination zero_point != 0 and isinstance(activ_attrs, ReluAttrs) and
isinstance(requant, ArithFoldedRequantization).
The quantizer must conform to this restriction.
:param conv_attrs: Attributes of the convolution operator.
:param weight_quant_data: Quantized weights data.
:param scale: Scale of the convolution operation.
:param zero_point: Zero point of the quantized output tensor.
:param input_zp: Zero point of input to the convolution.
:param bias_quant_data: Quantized bias data.
:param weight_bits: Number of bits used to quantize the weights.
:param bits: Number of bits used for quantization.
:param per_channel: If True, each output channel of the weights has an
independent scale.
:param activ_attrs: Activation attributes.
:param requant: Requantization to perform after the convolution and bias add.
:param input_int16: Whether the input tensor has int16 type. If True,
the operator executes using the 15-bit convolution algorithm.
:param msb_left_shift: Whether the 15-bit convolution algorithm will
left-shift the MSB (effectively right-shifting the full product by 1).
If False, it will right-shift the LSB (effectively right-shifting the
full product by 8).
Ignored if input_int16 is False.
"""
[docs]
weight_quant_data: np.ndarray
[docs]
requant: BaseRequantization[np.ndarray]
[docs]
bias_quant_data: Optional[np.ndarray] = None
[docs]
per_channel: bool = DEFAULT_PER_CHANNEL
[docs]
activ_attrs: Optional[QUANT_ACTIVATION_ATTRS] = None
[docs]
msb_left_shift: Union[bool, np.ndarray] = True
def __post_init__(self):
assert isinstance(self.conv_attrs, ConvAttrs)
assert isinstance(self.weight_quant_data, np.ndarray)
assert isinstance(self.requant, BaseRequantization)
assert isinstance(self.scale, float)
assert isinstance(self.zero_point, int)
assert isinstance(self.input_zp, int)
assert isinstance(self.bias_quant_data, (NoneType, np.ndarray))
assert isinstance(self.per_channel, bool)
assert isinstance(
self.activ_attrs, (NoneType, ReluQuantAttrs, ClipQuantAttrs, ReluAttrs, ClipAttrs)
)
[docs]
def set_batch_size(self, batch_size: int):
"""
Modify internal parameters' shapes for the given batch size.
"""
self.conv_attrs.set_batch_size(batch_size)
if self.activ_attrs is not None:
self.activ_attrs.set_batch_size(batch_size)
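# Illustrative sketch (not part of this module): the operator sequence the
# docstring above describes. "conv2d" and "run_requant" are hypothetical
# helpers standing in for the backend kernels.
#
#   acc = conv2d(inputs - input_zp, weight_quant_data)    # 1. convolution, integer accumulator
#   if bias_quant_data is not None:
#       acc = acc + bias_quant_data                        # 2. bias_add
#   out = run_requant(acc, requant)                        # 3. requantize (self.requant)
#   if isinstance(activ_attrs, (ReluAttrs, ReluQuantAttrs)):
#       out = np.maximum(out, zero_point)                  # 4. relu (clip is handled analogously)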
@dataclass
[docs]
class UpsamplingQuantAttrs(AwesomeQuantAttrBase):
"""
:param input_zp: Zero point of the input tensor.
:param rounding_type: Rounding type used in the quantized upsampling computation.
"""
[docs]
upsampling_attrs: UpsamplingAttrs
[docs]
rounding_type: RoundType = RoundType.TOEVEN
[docs]
def set_batch_size(self, batch_size: int):
"""
Modify internal parameters' shapes for the given batch size.
"""
self.upsampling_attrs.set_batch_size(batch_size)
@dataclass
[docs]
class ImageResize2DQuantAttrs(AwesomeQuantAttrBase):
"""
:param input_zp: Zero point of the input tensor.
:param rounding_type: Rounding type used in the quantized resize computation.
:param requant: Requantization to perform on the output.
:param input_int16: If True, the inputs have int16 type. If False, the inputs have int8 type.
"""
[docs]
image_resize2d_attrs: ImageResize2DAttrs
[docs]
rounding_type: RoundType = RoundType.TOEVEN
[docs]
requant: Optional[BaseRequantization] = None
[docs]
def set_batch_size(self, batch_size: int):
"""
Modify internal parameters' shapes for the given batch size.
"""
self.image_resize2d_attrs.set_batch_size(batch_size)
@dataclass
[docs]
class LRNQuantAttrs(AwesomeQuantAttrBase):
"""
:param axis: Input data layout channel axis. Default value is 1 for NCHW format.
:param size: The size of the local region to be considered for normalization.
:param lut_scale: The scale for quantization of the LUT input.
:param lut_zp_corr: The zero-point correction for quantization of the LUT input.
:param lut_sh: The shift for quantization of the LUT input.
:param output_scale: The scale for quantization of the output.
:param output_zp_corr: The zero-point correction for quantization of the output.
:param output_sh: The shift for quantization of the output.
:param lookup_table: Look-up table used by the quantized LRN.
# NOTES FOR TENSORFLOW
# TVM defines size as size_tvm = (depth_radius_tf * 2) + 1
# TVM defines alpha as alpha_tvm = alpha_tf * size_tf
"""
[docs]
lookup_table: np.ndarray
[docs]
def set_batch_size(self, batch_size: int):
"""
Modify internal parameters' shapes for the given batch size.
"""
self.shape = set_shape_batch_size(self.shape, batch_size)
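# Illustrative example (not part of this module): the TensorFlow -> TVM
# parameter relationship stated in the notes above, for hypothetical TF values
# (reading the window size as depth_radius * 2 + 1).
#
#   depth_radius_tf, alpha_tf = 2, 1e-4
#   size_tvm = (depth_radius_tf * 2) + 1   # -> 5
#   alpha_tvm = alpha_tf * size_tvm        # alpha is rescaled by the window size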
@dataclass
[docs]
class LayerNormQuantAttrs(AwesomeQuantAttrBase):
"""
:param axis: Indicates the dimension along which LayerNorm will be performed.
:param input_shape: Input shape.
:param lookup_table_rsqrt: Look-up table f(x) = 1 / sqrt(x + epsilon).
:param zp_rsqrt: Output zero point of the Rsqrt LUT.
:param requant_mean: Requantization parameters for input mean (integer inputs only).
:param requant_lut_input: Requantization parameters for Rsqrt LUT input.
:param requant_output: Requantization of final output.
"""
[docs]
lookup_table_rsqrt: np.ndarray
[docs]
requant_mean: BaseRequantization[np.ndarray]
[docs]
requant_output: BaseRequantization[np.ndarray]
[docs]
def set_batch_size(self, batch_size: int):
"""
Modify internal parameters' shapes for the given batch size.
"""
self.input_shape = set_shape_batch_size(self.input_shape, batch_size)
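# Illustrative sketch (not part of this module): the quantized LayerNorm
# pipeline these attributes describe. "run_requant" and "lut_lookup" are
# hypothetical helpers for applying a BaseRequantization and indexing a LUT.
#
#   mean = run_requant(np.mean(x, axis=axis, keepdims=True), requant_mean)
#   centered = x - mean
#   var = run_requant(np.mean(centered * centered, axis=axis, keepdims=True), requant_lut_input)
#   rsqrt = lut_lookup(lookup_table_rsqrt, var) - zp_rsqrt   # f(x) = 1 / sqrt(x + epsilon)
#   out = run_requant(centered * rsqrt, requant_output)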
@dataclass
[docs]
class InstanceNormQuantAttrs(AwesomeQuantAttrBase):
"""
Quantized Instance Normalization operator attributes.
Attributes:
attrs: InstanceNorm attributes.
lut_rsqrt: Look-up table f(x) = 1 / sqrt(x + epsilon).
zp_rsqrt: Output zero point of the Rsqrt LUT.
requant_out: Requantization of the output.
"""
[docs]
attrs: InstanceNormAttrs
[docs]
requant_out: BaseRequantization[np.ndarray]
@dataclass
[docs]
class RMSNormQuantAttrs(AwesomeQuantAttrBase):
"""
:param input_shape: Input shape.
:param zp_ifm: Input tensor zero point.
:param lookup_table_rsqrt: Look-up table f(x) = 1 / sqrt(x + epsilon).
:param zp_rsqrt: Output zero point of the Rsqrt LUT.
:param requant_lut_input: Requantization parameters for Rsqrt LUT input.
:param requant_output: Requantization of final output.
:param lut_input_pre_shift: LUT input requantization pre-shift value.
:param output_pre_shift: Output requantization pre-shift value.
:param enable_lut_int16: If True, quantize LUT to int16 otherwise to int8.
"""
[docs]
lookup_table_rsqrt: np.ndarray
[docs]
requant_output: BaseRequantization[np.ndarray]
[docs]
def set_batch_size(self, batch_size: int):
"""
Modify internal parameters' shapes for the given batch size.
"""
self.input_shape = set_shape_batch_size(self.input_shape, batch_size)
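# Illustrative sketch (not part of this module): the quantized RMSNorm pipeline
# these attributes describe; like LayerNorm above but without mean subtraction.
# "run_requant" and "lut_lookup" are hypothetical helpers, and the reduction
# axis below is a placeholder.
#
#   ms = run_requant(np.mean((x - zp_ifm) ** 2, axis=-1, keepdims=True), requant_lut_input)
#   rsqrt = lut_lookup(lookup_table_rsqrt, ms) - zp_rsqrt   # f(x) = 1 / sqrt(x + epsilon)
#   out = run_requant((x - zp_ifm) * rsqrt, requant_output)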
@dataclass
[docs]
class SoftmaxQuantAttrs(AwesomeQuantAttrBase):
"""
:param axis: Input data layout channel axis.
:param input_shape: Input shape.
:param exp_zp: Exp zero point.
:param rec_zp: Rec zero point.
:param requant_lut: Requantization parameters for quantization of reciprocal LUT input.
:param requant_output: Requantization parameters for output.
:param lookup_table_exp: LUT for exponential function.
:param lookup_table_rec: LUT for reciprocal function.
:param enable_int16: If True, int16 quantization is used; otherwise int8.
:param lut_input_pre_shift: LUT input requantization pre-shift value (int16 only).
:param output_pre_shift: Output requantization pre-shift value (int16 only).
"""
[docs]
requant_lut: BaseRequantization[np.ndarray]
[docs]
requant_output: BaseRequantization[np.ndarray]
[docs]
lookup_table_exp: np.ndarray
[docs]
lookup_table_rec: np.ndarray
[docs]
output_pre_shift: Optional[int] = None
[docs]
def set_batch_size(self, batch_size: int):
"""
Modify internal parameters' shapes for the given batch size.
"""
self.input_shape = set_shape_batch_size(self.input_shape, batch_size)
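# Illustrative sketch (not part of this module): the quantized softmax pipeline
# these attributes describe. "run_requant" and "lut_lookup" are hypothetical
# helpers; the int16 path additionally applies the documented pre-shift values.
#
#   exp_q = lut_lookup(lookup_table_exp, x)                                     # exp LUT, zero point exp_zp
#   denom = run_requant(np.sum(exp_q, axis=axis, keepdims=True), requant_lut)   # reciprocal LUT input
#   rec_q = lut_lookup(lookup_table_rec, denom)                                 # reciprocal LUT, zero point rec_zp
#   out = run_requant(exp_q * rec_q, requant_output)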
@dataclass
[docs]
class RequantizeQuantAttrs(AwesomeQuantAttrBase):
[docs]
requant: BaseRequantization[np.ndarray]
[docs]
def set_batch_size(self, batch_size: int):
"""
Modify internal parameters' shapes for the given batch size.
"""
self.attrs.set_batch_size(batch_size)
@dataclass
[docs]
class ConcatQuantAttrs(AwesomeQuantAttrBase):
"""
Contains quantization attributes for concatenate quantization.
:param attrs: ConcatenateAttrs holding ConcatenateOp parameters.
:param requants: Requantization parameters, one per input.
:param layer_bits: Number of bits used for quantizing the tensor.
:param input_scales: Quantized scale for each input.
:param node_scales: Output scale(s) of the quantized output tensor(s); the maximum of input_scales is used as the concatenate output scale.
:param node_zps: Zero point(s) of the quantized output tensor(s).
"""
[docs]
attrs: ConcatenateAttrs
[docs]
requants: List[BaseRequantization[np.ndarray]]
[docs]
layer_bits: List[int] = field(default_factory=lambda: [8])
[docs]
node_scales: List[float] = field(default_factory=list) # For Backend Model_Builder
[docs]
node_zps: Optional[List[int]] = None # For graph_analyzer.
def __post_init__(self):
assert isinstance(self.attrs, ConcatenateAttrs), \
f"Type mismatch for attrs field. Got {type(self.attrs)}, expected ConcatenateAttrs."
assert isinstance(self.requants, List), \
f"Type mismatch for requants field. " \
f"Got {type(self.requants)}, expected List."
assert isinstance(self.layer_bits, List), \
f"Type mismatch for layer_bits field. Got {type(self.layer_bits)}, expected List."
assert isinstance(self.input_scales, List), \
f"Type mismatch for input_scales field. Got {type(self.input_scales)}, expected List."
assert isinstance(self.node_scales, List), \
f"Type mismatch for node_scales field. Got {type(self.node_scales)}, expected List."
assert isinstance(self.node_zps, List), \
f"Type mismatch for node_zps field. Got {type(self.node_zps)}, expected List."
assert len(self.requants) == len(self.input_scales), "Mismatch for length of List fields."
for requant in self.requants:
assert isinstance(requant, BaseRequantization), \
f"Type mismatch for input_scale_corrections field. Got {type(requant)}, expected BaseRequantization."
for in_sc in self.input_scales:
assert isinstance(in_sc, (float, List)), \
f"Type mismatch for input_scales field. Got List[{type(in_sc)}], " \
f"expected List[Union[float, [List[float]]]]."
if isinstance(in_sc, List):
for sc in in_sc:
assert isinstance(sc, float), \
f"Type mismatch for input_scales field. Got List[List[{type(sc)}]], " \
f"expected List[Union[float, [List[float]]]]."
[docs]
def set_batch_size(self, batch_size: int):
"""
Modify internal parameters' shapes for the given batch size.
"""
self.attrs.set_batch_size(batch_size)
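# Illustrative sketch (not part of this module): each input is requantized to
# the common output scale and then concatenated. "run_requant" is a
# hypothetical helper for applying a BaseRequantization, and "concat_axis"
# stands for the axis carried by attrs.
#
#   rescaled = [run_requant(t, rq) for t, rq in zip(input_tensors, requants)]
#   out = np.concatenate(rescaled, axis=concat_axis)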
@dataclass
[docs]
class CustomOpQuantAttrs(AwesomeQuantAttrBase):
"""
Contains quantization attributes for custom operation quantization.
:param custom_op_attrs: CustomOp attributes.
:param layer_bits: Number of bits used for quantizing the tensor.
:param node_zps: Zero point(s) of the quantized output tensor(s).
:param node_scales: Output scale(s) of the quantized output tensor(s).
:param input_zps: Quantized zero-point correction for each input.
:param input_scales: Quantized scales for each input.
"""
[docs]
custom_op_attrs: CustomOpAttrs
[docs]
layer_bits: List[int] = field(default_factory=lambda: [8])
[docs]
node_zps: List[int] = field(default_factory=list)
[docs]
node_scales: List[float] = field(default_factory=list)
@dataclass
[docs]
class PoolQuantAttrs(AwesomeQuantAttrBase):
"""
Contains quantization attributes for pool quantization.
:param pool_attrs: Pool attrs class holding MaxPool/AvgPool operator parameters.
Its scalar type does not determine the scalar type for the quantized operator.
:param pad_value: Padding value.
:param rounding_type: RoundType.
:param requant: Requantization to perform on the output, if any.
"""
[docs]
pad_value: Union[float, int]
[docs]
rounding_type: RoundType
[docs]
requant: Optional[BaseRequantization] = None
[docs]
def set_batch_size(self, batch_size: int):
"""
Modify internal parameters' shapes for the given batch size.
"""
self.pool_attrs.set_batch_size(batch_size)
@dataclass
[docs]
class VarianceQuantAttrs(AwesomeQuantAttrBase):
"""
Attributes:
attrs: Variance attributes.
requant: Requantization of the intermediate values.
requant_var: Requantization of the Variance operator final output.
"""
[docs]
requant: BaseRequantization
[docs]
requant_var: BaseRequantization
@dataclass
[docs]
class UDFQuantAttrs(AwesomeQuantAttrBase):
[docs]
output_signed: bool = False
[docs]
lookup_table: Optional[np.ndarray] = None
[docs]
requant: Optional[BaseRequantization] = None
[docs]
def set_batch_size(self, batch_size: int):
"""
Modify internal parameters' shapes for the given batch size.
"""
self.attrs.set_batch_size(batch_size)
@dataclass
[docs]
class DivideAttrs(AwesomeAttributes):
[docs]
multiply_attrs: MultiplyAttrs
[docs]
def set_batch_size(self, batch_size: int):
"""
Modify internal parameters' shapes for the given batch size.
"""
self.udf_attrs.set_batch_size(batch_size)
self.multiply_attrs.set_batch_size(batch_size)
@dataclass
[docs]
class DivideQuantAttrs(AwesomeQuantAttrBase):
[docs]
udf_attrs: UDFQuantAttrs
[docs]
multiply_attrs: MultiplyQuantAttrs
[docs]
def set_batch_size(self, batch_size: int):
"""
Modify internal parameters' shapes for the given batch size.
"""
self.udf_attrs.set_batch_size(batch_size)
self.multiply_attrs.set_batch_size(batch_size)
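# Illustrative sketch (not part of this module): these attributes suggest that
# a quantized divide is decomposed into a reciprocal evaluated through the UDF
# lookup-table path followed by a quantized multiply; the helper names below
# are hypothetical.
#
#   rec_rhs = udf_lookup(rhs)                # 1 / rhs via udf_attrs.lookup_table
#   out = quantized_multiply(lhs, rec_rhs)   # multiply_attrs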
@dataclass
[docs]
class LeakyReluCompositeQuantAttrs(AwesomeQuantAttrBase):
"""
Contains quantization attributes for both UDF and breakdown LeakyRelu quantization.
:param attrs: LeakyRelu attributes class holding LeakyReluOp parameters.
:param leaky_relu_uses_udf: bool. If True, use the UDF version in quantization; otherwise, use the breakdown version.
:param leaky_relu_quant_attrs: Quantization parameters for the breakdown version, if it is used.
:param udf_quant_attrs: Quantization parameters for the UDF version, if it is used.
"""
[docs]
leaky_relu_uses_udf: bool = True
[docs]
leaky_relu_quant_attrs: Optional[LeakyReluQuantAttrs] = None
[docs]
udf_quant_attrs: Optional[UDFQuantAttrs] = None
[docs]
def set_batch_size(self, batch_size: int):
"""
Modify internal parameters' shapes for the given batch size.
"""
self.attrs.set_batch_size(batch_size)
if self.leaky_relu_quant_attrs is not None:
self.leaky_relu_quant_attrs.set_batch_size(batch_size)
if self.udf_quant_attrs is not None:
self.udf_quant_attrs.set_batch_size(batch_size)
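# Illustrative sketch (not part of this module): which parameter set is
# consulted, per the flag documented above.
#
#   quant_attrs = udf_quant_attrs if leaky_relu_uses_udf else leaky_relu_quant_attrs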
@dataclass
[docs]
class PReluQuantAttrs(AwesomeQuantAttrBase):
"""
The slope for quantized_input < zero_point is (alpha >> right_shift).
"""
[docs]
quant_alpha: np.ndarray
# PRelu has the same zero point as its input
[docs]
data_zero_point: int = 0
[docs]
def set_batch_size(self, batch_size: int):
"""
Modify internal parameters' shapes for the given batch size.
"""
self.input_shape = set_shape_batch_size(self.input_shape, batch_size)
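# Illustrative sketch (not part of this module): quantized PRelu with the
# negative-side slope realized as (alpha >> right_shift), as described above.
# "right_shift" is a hypothetical name for the shift that accompanies quant_alpha.
#
#   neg = ((x - data_zero_point) * quant_alpha) >> right_shift
#   out = np.where(x < data_zero_point, neg + data_zero_point, x)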
@dataclass
[docs]
class PowerAttrs(AwesomeAttributes):
[docs]
def set_batch_size(self, batch_size: int):
"""
Modify internal parameters' shapes for the given batch size.
"""
self.lhs_input_shape = set_shape_batch_size(self.lhs_input_shape, batch_size)
self.rhs_input_shape = set_shape_batch_size(self.rhs_input_shape, batch_size)
@dataclass
[docs]
class ArgMaxQuantAttrs(AwesomeQuantAttrBase):
# Use the same format as ArgMax for attributes
[docs]
def set_batch_size(self, batch_size: int):
"""
Modify internal parameters' shapes for the given batch size.
"""
self.attrs.set_batch_size(batch_size)
@dataclass
[docs]
class BatchMatmulAttrs(AwesomeAttributes):
[docs]
scalar_type: ScalarType
def __post_init__(self):
assert len(self.input_shapes) == 2, \
"Unsupported number of input data operands in einsum operator"
[docs]
def set_batch_size(self, batch_size: int):
"""
Modify internal parameters' shapes for the given batch size.
"""
self.input_shapes = [set_shape_batch_size(in_shape, batch_size)
for in_shape in self.input_shapes]
[docs]
def get_output_shape(self) -> Tuple[int, ...]:
nhw_out = self.input_shapes[0][0:3]
c_out = self.input_shapes[1][2] if self.transpose_b else self.input_shapes[1][3]
return *nhw_out, c_out
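# Illustrative example (not part of this module): with hypothetical input shapes
# (1, 8, 16, 32) and (1, 8, 32, 64) and transpose_b=False, get_output_shape()
# keeps the first input's leading dims and takes the output channels from the
# second input, giving (1, 8, 16, 64).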
@dataclass
[docs]
class BatchMatmulQuantAttrs(AwesomeQuantAttrBase):
[docs]
attrs: BatchMatmulAttrs
[docs]
requant: BaseRequantization
[docs]
def set_batch_size(self, batch_size: int):
self.attrs.set_batch_size(batch_size)
@dataclass
[docs]
class SliceConcatAttrs(AwesomeAttributes):
[docs]
slice_attrs: List[StridedSliceAttrs]
[docs]
tuple_concat_attrs: TupleConcatenateAttrs
@dataclass
[docs]
class SliceConcatQuantAttrs(AwesomeQuantAttrBase):
[docs]
slice_attrs: List[StridedSliceAttrs]
[docs]
tuple_concat_attrs: ConcatQuantAttrs
@dataclass
[docs]
class BroadcastToQuantAttrs(AwesomeQuantAttrBase):
[docs]
output_shape: Tuple[int, ...]
[docs]
def set_batch_size(self, batch_size: int):
"""
Modify internal parameters' shapes for the given batch size.
"""
self.input_type = set_tensor_type_batch_size(self.input_type, batch_size)
self.output_shape = set_shape_batch_size(self.output_shape, batch_size)
def __post_init__(self):
input_shape = self.input_type.shape
for dim in range(len(input_shape)):
if input_shape[dim] != self.output_shape[dim] and input_shape[dim] != 1:
raise RuntimeError(f"Invalid BroadcastTo operation. Broadcasting is supported only with input"
f"dimension size 1. Input's dimension {dim} has size of {input_shape[dim]}")