Source code for sima_qat.onnx_ops

#**************************************************************************
#||                        SiMa.ai CONFIDENTIAL                          ||
#||   Unpublished Copyright (c) 2024 SiMa.ai, All Rights Reserved.       ||
#**************************************************************************
# NOTICE:  All information contained herein is, and remains the property of
# SiMa.ai. The intellectual and technical concepts contained herein are
# proprietary to SiMa and may be covered by U.S. and Foreign Patents,
# patents in process, and are protected by trade secret or copyright law.
#
# Dissemination of this information or reproduction of this material is
# strictly forbidden unless prior written permission is obtained from
# SiMa.ai.  Access to the source code contained herein is hereby forbidden
# to anyone except current SiMa.ai employees, managers or contractors who
# have executed Confidentiality and Non-disclosure agreements explicitly
# covering such access.
#
# The copyright notice above does not evidence any actual or intended
# publication or disclosure  of  this source code, which includes information
# that is confidential and/or proprietary, and is a trade secret, of SiMa.ai.
#
# ANY REPRODUCTION, MODIFICATION, DISTRIBUTION, PUBLIC PERFORMANCE, OR PUBLIC
# DISPLAY OF OR THROUGH USE OF THIS SOURCE CODE WITHOUT THE EXPRESS WRITTEN
# CONSENT OF SiMa.ai IS STRICTLY PROHIBITED, AND IN VIOLATION OF APPLICABLE
# LAWS AND INTERNATIONAL TREATIES. THE RECEIPT OR POSSESSION OF THIS SOURCE
# CODE AND/OR RELATED INFORMATION DOES NOT CONVEY OR IMPLY ANY RIGHTS TO
# REPRODUCE, DISCLOSE OR DISTRIBUTE ITS CONTENTS, OR TO MANUFACTURE, USE, OR
# SELL ANYTHING THAT IT  MAY DESCRIBE, IN WHOLE OR IN PART.
#
#**************************************************************************
import functools

import torch
import torch._C._onnx as _C_onnx
import torch.nn.modules.utils
import torch.onnx
from torch.onnx import (
    _type_utils,
    errors,
    symbolic_helper,
)


from torch.onnx._internal import _beartype, jit_utils, registration


# Q/DQ operators in ONNX had a major revision at Opset 13. Opset 19 was the next revision,
# and those changes are not relevant for SiMa.
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=13)



@_onnx_symbolic("quantized_decomposed::quantize_per_tensor")
@symbolic_helper.parse_args("v", "v", "v", "i", "i", "v")
@_beartype.beartype
def fake_quantize_per_tensor_affine(
    g: jit_utils.GraphContext,
    inputs,
    scale,
    zero_point,
    quant_min=-128,
    quant_max=127,
    dtype=torch.dtype,
):
    # NOTE: (0, 127) is allowed as a special case. PyTorch restricts activations to be in the range (0, 127).
    # https://github.com/pytorch/pytorch/blob/b34b192d6b97325c9f78e5995c48c8498ede34bd/torch/ao/quantization/observer.py#L1422
    if (quant_min, quant_max) not in [(0, 255), (-128, 127), (0, 127)]:
        raise errors.SymbolicValueError(
            "For (quant_min, quant_max), ONNX allows only (0, 127), (0, 255) and (-128, 127). "
            f"Got ({quant_min}, {quant_max})",
            inputs,
        )

    # ONNX defines zero_point to be int8 or uint8.
    if quant_min == 0:
        zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.UINT8)
    else:
        zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.INT8)

    # Assert-based for now; this will catch logical graph problems.
    input_type = _type_utils.JitScalarType.from_value(inputs, _type_utils.JitScalarType.UNDEFINED)
    assert input_type == _type_utils.JitScalarType.FLOAT

    # This is apparently important, because scale can come in as a double (!?) on this function call.
    if (
        _type_utils.JitScalarType.from_value(scale, _type_utils.JitScalarType.UNDEFINED)
        != _type_utils.JitScalarType.FLOAT
    ):
        scale = g.op("Cast", scale, to_i=_C_onnx.TensorProtoDataType.FLOAT)

    quantized = g.op("QuantizeLinear", inputs, scale, zero_point)
    assert (
        _type_utils.JitScalarType.from_value(quantized, _type_utils.JitScalarType.UNDEFINED)
        == _type_utils.JitScalarType.INT8
    )
    return quantized


@_onnx_symbolic("quantized_decomposed::dequantize_per_tensor")
@symbolic_helper.parse_args("v", "v", "v", "i", "i", "v")
@_beartype.beartype
def fake_dequantize_per_tensor_affine(
    g: jit_utils.GraphContext,
    inputs,
    scale,
    zero_point,
    quant_min=-128,
    quant_max=127,
    dtype=torch.dtype,
):
    # ONNX defines zero_point to be int8 or uint8.
    if quant_min == 0:
        zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.UINT8)
    else:
        zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.INT8)

    input_type = _type_utils.JitScalarType.from_value(inputs, _type_utils.JitScalarType.UNDEFINED)
    assert input_type == _type_utils.JitScalarType.INT8
    quantized = inputs

    # This is apparently important, because scale can come in as a double (!?) on this function call.
    if (
        _type_utils.JitScalarType.from_value(scale, _type_utils.JitScalarType.UNDEFINED)
        != _type_utils.JitScalarType.FLOAT
    ):
        scale = g.op("Cast", scale, to_i=_C_onnx.TensorProtoDataType.FLOAT)

    dq = g.op("DequantizeLinear", quantized, scale, zero_point)
    assert (
        _type_utils.JitScalarType.from_value(dq, _type_utils.JitScalarType.UNDEFINED)
        == _type_utils.JitScalarType.FLOAT
    )
    return dq


@_onnx_symbolic("quantized_decomposed::dequantize_per_channel")
@symbolic_helper.parse_args("v", "v", "v", "i", "i", "i", "v")
@_beartype.beartype
def fake_quantize_per_channel_affine(
    g: jit_utils.GraphContext,
    inputs,
    scales,
    zero_points,
    axis,
    quant_min=-128,
    quant_max=127,
    dtype=torch.dtype,
):
    # NOTE: (0, 127) is allowed as a special case. PyTorch restricts activations to be in the range (0, 127).
    # https://github.com/pytorch/pytorch/blob/b34b192d6b97325c9f78e5995c48c8498ede34bd/torch/ao/quantization/observer.py#L1422
    # if (quant_min, quant_max) not in [(0, 255), (-128, 127), (0, 127)]:
    #     raise errors.SymbolicValueError(
    #         "For (quant_min, quant_max), ONNX allows only (0, 127), (0, 255) and (-128, 127). "
    #         f"Got ({quant_min}, {quant_max})",
    #         inputs,
    #     )

    # ONNX defines zero_point to be int8 or uint8.
    if quant_min == 0:
        zero_points = g.op("Cast", zero_points, to_i=_C_onnx.TensorProtoDataType.UINT8)
    else:
        zero_points = g.op("Cast", zero_points, to_i=_C_onnx.TensorProtoDataType.INT8)

    input_type = _type_utils.JitScalarType.from_value(inputs, _type_utils.JitScalarType.UNDEFINED)
    assert input_type == _type_utils.JitScalarType.INT8
    quantized = inputs

    # axis_cast = g.op("Cast", axis, to_i=_C_onnx.TensorProtoDataType.INT32)

    # This is apparently important, because scale can come in as a double (!?) on this function call.
    if (
        _type_utils.JitScalarType.from_value(scales, _type_utils.JitScalarType.UNDEFINED)
        != _type_utils.JitScalarType.FLOAT
    ):
        scales = g.op("Cast", scales, to_i=_C_onnx.TensorProtoDataType.FLOAT)

    dq = g.op("DequantizeLinear", quantized, scales, zero_points, axis_i=axis)
    assert (
        _type_utils.JitScalarType.from_value(dq, _type_utils.JitScalarType.UNDEFINED)
        == _type_utils.JitScalarType.FLOAT
    )
    return dq
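

# Minimal usage sketch (illustrative only, not part of the SiMa sources): once this
# module is imported, the registrations above are visible to the TorchScript ONNX
# exporter, so tracing a module that calls the quantized_decomposed ops and exporting
# it at opset 13 lowers each call to an ONNX QuantizeLinear/DequantizeLinear node.
# The op schema used below (float scale, int zero_point, quant_min, quant_max, dtype)
# is assumed to match recent PyTorch 2.x releases; adjust if your version differs.
if __name__ == "__main__":

    class _QdqExample(torch.nn.Module):
        def forward(self, x):
            # Quantize to int8 with a fixed scale/zero-point, then dequantize back.
            q = torch.ops.quantized_decomposed.quantize_per_tensor(
                x, 0.02, 0, -128, 127, torch.int8
            )
            return torch.ops.quantized_decomposed.dequantize_per_tensor(
                q, 0.02, 0, -128, 127, torch.int8
            )

    # Exporting at opset 13 exercises the symbolic functions defined in this module.
    torch.onnx.export(
        _QdqExample(),
        torch.randn(1, 3, 8, 8),
        "qdq_example.onnx",
        opset_version=13,
    )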