#**************************************************************************
#|| SiMa.ai CONFIDENTIAL ||
#|| Unpublished Copyright (c) 2024 SiMa.ai, All Rights Reserved. ||
#**************************************************************************
# NOTICE: All information contained herein is, and remains the property of
# SiMa.ai. The intellectual and technical concepts contained herein are
# proprietary to SiMa and may be covered by U.S. and Foreign Patents,
# patents in process, and are protected by trade secret or copyright law.
#
# Dissemination of this information or reproduction of this material is
# strictly forbidden unless prior written permission is obtained from
# SiMa.ai. Access to the source code contained herein is hereby forbidden
# to anyone except current SiMa.ai employees, managers or contractors who
# have executed Confidentiality and Non-disclosure agreements explicitly
# covering such access.
#
# The copyright notice above does not evidence any actual or intended
# publication or disclosure of this source code, which includes information
# that is confidential and/or proprietary, and is a trade secret, of SiMa.ai.
#
# ANY REPRODUCTION, MODIFICATION, DISTRIBUTION, PUBLIC PERFORMANCE, OR PUBLIC
# DISPLAY OF OR THROUGH USE OF THIS SOURCE CODE WITHOUT THE EXPRESS WRITTEN
# CONSENT OF SiMa.ai IS STRICTLY PROHIBITED, AND IN VIOLATION OF APPLICABLE
# LAWS AND INTERNATIONAL TREATIES. THE RECEIPT OR POSSESSION OF THIS SOURCE
# CODE AND/OR RELATED INFORMATION DOES NOT CONVEY OR IMPLY ANY RIGHTS TO
# REPRODUCE, DISCLOSE OR DISTRIBUTE ITS CONTENTS, OR TO MANUFACTURE, USE, OR
# SELL ANYTHING THAT IT MAY DESCRIBE, IN WHOLE OR IN PART.
#
#**************************************************************************
from __future__ import annotations
import copy
import operator
import functools
import itertools
from typing import Any, Callable, Dict, List, Optional, Set
import torch
import torch._dynamo as torchdynamo
import torch.nn.functional as F
from torch.ao.quantization.fake_quantize import (
FakeQuantize,
FusedMovingAvgObsFakeQuantize,
)
from torch.ao.quantization.observer import (
HistogramObserver,
MinMaxObserver,
MovingAverageMinMaxObserver,
MovingAveragePerChannelMinMaxObserver,
PerChannelMinMaxObserver,
PlaceholderObserver,
)
from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor
from torch.ao.quantization.quantizer import (
QuantizationSpec,
Quantizer,
QuantizationAnnotation,
)
from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import (
_convert_scalars_to_attrs,
OP_TO_ANNOTATOR,
OperatorConfig,
OperatorPatternType,
propagate_annotation,
QuantizationConfig,
_is_annotated,
get_input_act_qspec,
get_output_act_qspec,
register_annotator,
_is_input_non_float_tensor,
_is_input_large_scalar,
get_weight_qspec,
get_bias_qspec,
_mark_nodes_as_annotated,
_WrapperModule
)
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
_get_module_type_filter,
_get_dynamo_graph,
_get_linear_patterns,
_get_module_name_filter,
_get_not_module_type_or_name_filter,
)
from torch.ao.quantization.pt2e.utils import (
_conv1d_bn_example_inputs,
_conv2d_bn_example_inputs,
get_aten_graph_module,
)
from torch.fx.passes.utils.matcher_with_name_node_map_utils import (
SubgraphMatcherWithNameNodeMap,
)
from torch.fx import Node
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions
from torch.ao.quantization.pt2e.graph_utils import find_sequential_partitions
# from torch.ops.quantized_decomposed import quantize_per_tensor
__all__ = [
"SimaQuantizer",
"get_sima_quantization_config",
]
class SimaMovingAverageMinMaxObserver(MovingAverageMinMaxObserver):
r"""
    We override the vanilla PyTorch MovingAverageMinMaxObserver with a specialized
    version that computes min/max per sample in the batch and then averages the
    per-sample results. This gives more consistency across samples and makes
    training behavior largely independent of the batch-size setting.
"""
def forward(self, x_orig):
if x_orig.numel() == 0:
return x_orig
x = x_orig.detach() # avoid keeping autograd tape
x = x.to(self.min_val.dtype)
min_val = self.min_val
max_val = self.max_val
if min_val == float("inf") and max_val == float("-inf"):
min_val, max_val = torch.aminmax(x)
else:
if len(x.shape) < 2:
min_val_cur, max_val_cur = torch.aminmax(x)
else:
                # Flatten each sample, take per-sample min/max, then average
                # across the batch so a single outlier sample has limited effect.
if x.is_contiguous():
x_flat = x.view(x.shape[0], -1)
else:
x_flat = x.reshape(x.shape[0], -1)
min_val_batch, max_val_batch = torch.aminmax(x_flat, dim=1)
min_val_cur = min_val_batch.mean()
max_val_cur = max_val_batch.mean()
min_val = min_val + self.averaging_constant * (min_val_cur - min_val)
max_val = max_val + self.averaging_constant * (max_val_cur - max_val)
self.min_val.copy_(min_val)
self.max_val.copy_(max_val)
return x_orig
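# A minimal behavioral sketch of the observer above (values are illustrative):
# on calls after the first (the first call seeds min/max with a whole-tensor
# aminmax), the running range tracks the mean of the per-sample min/max rather
# than the global min/max, so a single outlier sample moves the range far less.
#
#   obs = SimaMovingAverageMinMaxObserver(averaging_constant=0.01)
#   x = torch.randn(8, 3, 32, 32)                  # (batch, C, H, W)
#   obs(x)                                         # first call: whole-tensor aminmax
#   obs(torch.randn(8, 3, 32, 32))                 # later calls: per-sample averaging
#   per_sample_min, _ = torch.aminmax(x.view(8, -1), dim=1)
#   # subsequent updates move obs.min_val toward per_sample_min.mean(), not x.min()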
def _supported_symmetric_quantized_operators() -> Dict[str, List[OperatorPatternType]]:
supported_operators: Dict[str, List[OperatorPatternType]] = {
# Both conv and linear should be able to handle relu + hardtanh fusion since
# those are clamp ops
"conv2d": [
[torch.nn.Conv2d, torch.nn.ReLU],
[torch.nn.Conv2d, F.relu],
[F.conv2d, torch.nn.ReLU],
[F.conv2d, F.relu],
],
"linear": [[torch.nn.Linear], [F.linear]],
"add": [[torch.add]],
"max_pool2d": [[torch.nn.MaxPool2d], [F.max_pool2d]],
"adaptive_avg_pool2d": [
[torch.nn.AdaptiveAvgPool2d],
[F.adaptive_avg_pool2d],
],
}
return copy.deepcopy(supported_operators)
def _get_supported_symmetric_config_and_operators() -> List[OperatorConfig]:
supported_config_and_operators: List[OperatorConfig] = []
for quantization_config in [
get_sima_quantization_config(),
get_sima_quantization_config(is_qat=True),
]:
ops = _supported_symmetric_quantized_operators()
for pattern_list in ops.values():
supported_config_and_operators.append(
OperatorConfig(quantization_config, pattern_list)
)
return copy.deepcopy(supported_config_and_operators)
@functools.lru_cache
def get_sima_quantization_config(
is_qat: bool = False,
):
    # This configuration function has a single parameter: whether to use QAT.
    # SiMa has a preferred encoding for activation and weight tensors that gives
    # the best possible results. Since QAT is a high-effort activity, we only
    # expose the best-performing quantization settings here.
#
# Activations
# ---------------------------------------------------
act_extra_args: Dict[str, Any] = {"eps": 2**-12}
if is_qat:
act_observer_or_fake_quant_ctr = FakeQuantize
act_extra_args["observer"] = SimaMovingAverageMinMaxObserver
else:
        # If QAT is disabled (PTQ), use a histogram observer to collect calibration statistics.
act_observer_or_fake_quant_ctr = HistogramObserver # type: ignore[assignment]
    # Activations use a fixed set of parameters that never needs to change:
    # always per-tensor, signed 8-bit integer encoding.
    #
act_quantization_spec = QuantizationSpec(
dtype = torch.int8,
quant_min = -128,
quant_max = 127,
qscheme = torch.per_tensor_affine,
is_dynamic = False,
observer_or_fake_quant_ctr = act_observer_or_fake_quant_ctr.with_args(
**act_extra_args,
),
)
# Weights
# ---------------------------------------------------
# Weights will always be captured as per-channel symmetric.
wt_extra_args: Dict[str, Any] = {"eps": 2**-12}
if is_qat:
weight_observer_or_fake_quant_ctr = PerChannelMinMaxObserver
else:
weight_observer_or_fake_quant_ctr = PlaceholderObserver
weight_quantization_spec = QuantizationSpec(
dtype = torch.int8,
quant_min = -127, # use -128 ??
quant_max = 127,
qscheme = torch.per_channel_symmetric,
ch_axis = 0, # Weights are always channel-first
is_dynamic = False,
observer_or_fake_quant_ctr = weight_observer_or_fake_quant_ctr.with_args(
**wt_extra_args
),
)
bias_quantization_spec = None
quantization_config = QuantizationConfig(
input_activation = act_quantization_spec,
output_activation = act_quantization_spec,
weight = weight_quantization_spec,
bias = bias_quantization_spec,
is_qat = is_qat,
)
return quantization_config
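# A minimal usage sketch, assuming the intended flow (SimaQuantizer is defined
# below; get_sima_quantization_config results are cached by functools.lru_cache):
#
#   ptq_config = get_sima_quantization_config()             # PTQ: histogram observers
#   qat_config = get_sima_quantization_config(is_qat=True)  # QAT: fake-quant + moving-average observer
#   quantizer = SimaQuantizer().set_global(qat_config)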
def _get_supported_config_and_operators() -> List[OperatorConfig]:
return _get_supported_symmetric_config_and_operators()
class SimaQuantizer(Quantizer):
""" This quantizer definition uses XNNPACK implementation for the majority of ops. This is because
the XNNPACK code simply looks for appropriate patterns, and applies the QuantizationAnnotation
attributes accordingly. The quantization rules are defined separately, and specified here using
Sima properties.
"""
supported_config_and_operators = _get_supported_config_and_operators()
STATIC_QAT_ONLY_OPS = [
"sima_conv_bn_hardtanh",
"conv_bn_relu",
"conv_bn",
]
    # Static quantization ops (both PTQ and QAT).
    # Order matters: fused patterns must come before their singular ops.
STATIC_OPS = [
"linear_relu",
"linear",
"sima_conv_add_or_mul_const",
"sima_conv_hardtanh",
"conv_relu",
"conv",
"adaptive_avg_pool2d",
"max_pool2d",
"sima_add_hardtanh",
"add_relu",
"add",
"mul_relu",
"mul",
"sima_cat",
"sima_sigmoid",
"sima_silu",
"sima_slice_select_unsqueeze",
"sima_batchnorm"
]
def __init__(self):
super().__init__()
self.global_config: Optional[QuantizationConfig] = None
self.operator_type_config: Dict[
torch._ops.OpOverloadPacket, Optional[QuantizationConfig]
] = {}
self.module_type_config: Dict[Callable, Optional[QuantizationConfig]] = {}
self.module_name_config: Dict[str, Optional[QuantizationConfig]] = {}
@classmethod
def get_supported_quantization_configs(cls) -> List[QuantizationConfig]:
op_configs: Set[QuantizationConfig] = set({})
for spec, _ in cls.supported_config_and_operators:
op_configs.add(spec)
return list(op_configs)
@classmethod
def get_supported_operator_for_quantization_config(
cls, quantization_config: Optional[QuantizationConfig]
) -> List[OperatorPatternType]:
if quantization_config is None:
all_ops = []
for _, ops in cls.supported_config_and_operators:
all_ops.extend(ops)
return all_ops
for config, ops in cls.supported_config_and_operators:
# note: this assumes each entry in cls.supported_spec_and_operators
# corresponds to one spec, e.g. we don't have
# [(spec1, op_list1), (spec1, op_list2), (spec2, op_list3)]
# where the first and second entry have the same spec but did not
# merge the op list
if config == quantization_config:
return ops
return []
def set_global(self, quantization_config: QuantizationConfig) -> SimaQuantizer:
self.global_config = quantization_config
return self
def set_operator_type(
self,
operator_type: torch._ops.OpOverloadPacket,
quantization_config: QuantizationConfig,
) -> SimaQuantizer:
self.operator_type_config[operator_type] = quantization_config
return self
def set_module_type(
self, module_type: Callable, quantization_config: QuantizationConfig
):
"""Set quantization_config for a submodule with type: `module_type`, for example:
quantizer.set_module_name(Sub) or quantizer.set_module_name(nn.Linear), it will quantize all supported operator/operator
patterns in the submodule with this module type with the given `quantization_config`
"""
self.module_type_config[module_type] = quantization_config
return self
def set_module_name(
self, module_name: str, quantization_config: Optional[QuantizationConfig]
):
"""Set quantization_config for a submodule with name: `module_name`, for example:
quantizer.set_module_name("blocks.sub"), it will quantize all supported operator/operator
patterns in the submodule with this module name with the given `quantization_config`
"""
assert (
quantization_config is not None
), " quantization_config == None is not supported yet"
self.module_name_config[module_name] = quantization_config
return self
def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
"""just handling global spec for now"""
# Dynamic is unsupported.
if self.global_config and self.global_config.input_activation.is_dynamic: # type: ignore[union-attr]
assert False, "Error: dynamic quantization is unsupported on Sima models."
model = self._annotate_for_static_quantization_config(model)
propagate_annotation(model)
return model
def _annotate_all_static_patterns(
self,
model: torch.fx.GraphModule,
quantization_config: Optional[QuantizationConfig],
filter_fn: Optional[Callable[[Node], bool]] = None,
) -> torch.fx.GraphModule:
# TODO: implement the support for None to be canceling out previous annotations
if quantization_config is None:
return model
if quantization_config.is_qat:
for op in self.STATIC_QAT_ONLY_OPS:
OP_TO_ANNOTATOR[op](model, quantization_config, filter_fn)
for op in self.STATIC_OPS:
OP_TO_ANNOTATOR[op](model, quantization_config, filter_fn)
return model
def _annotate_for_static_quantization_config(
self, model: torch.fx.GraphModule
) -> torch.fx.GraphModule:
module_name_list = list(self.module_name_config.keys())
for module_name, config in self.module_name_config.items():
self._annotate_all_static_patterns(
model, config, _get_module_name_filter(module_name)
)
tp_list = list(self.module_type_config.keys())
for module_type, config in self.module_type_config.items():
self._annotate_all_static_patterns(
model, config, _get_module_type_filter(module_type)
)
self._annotate_all_static_patterns(
model,
self.global_config,
_get_not_module_type_or_name_filter(tp_list, module_name_list),
)
return model
def validate(self, model: torch.fx.GraphModule) -> None:
pass
@classmethod
def get_supported_operators(cls) -> List[OperatorConfig]:
return cls.supported_config_and_operators
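# End-to-end sketch (hedged): the capture entry point varies across PyTorch 2.x
# releases (e.g. torch._export.capture_pre_autograd_graph in earlier releases,
# torch.export.export_for_training later); the prepare/convert calls below are
# the standard PT2E APIs and are shown only to illustrate where SimaQuantizer fits.
#
#   from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
#
#   quantizer = SimaQuantizer().set_global(get_sima_quantization_config())
#   captured = torch._export.capture_pre_autograd_graph(model, example_inputs)
#   prepared = prepare_pt2e(captured, quantizer)
#   prepared(*example_inputs)        # calibration pass (PTQ) or training loop (QAT)
#   converted = convert_pt2e(prepared)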
def _annotate_single_op(
quantization_config: QuantizationConfig,
op_partitions: List[object],
op_check: Callable,
) -> List[List[Node]]:
""" This is a helper function which annotates a single operation in a graph with Fakequant
observers. This function assumes single-input operators, and is not suitable for multi-input
ops which may have constant inputs (e.g. Conv2D).
"""
annotated_partitions = []
for op_partition in op_partitions:
op_node = op_partition.output_nodes[0]
if _is_annotated([op_node]):
continue
if not op_check(op_node):
continue
annotated_partitions.append(op_partition.nodes)
input_act_qspec = get_input_act_qspec(quantization_config)
input_act0 = op_node.args[0]
input_qspec_map = {}
if isinstance(input_act0, Node):
input_qspec_map[input_act0] = input_act_qspec
output_act_qspec = get_output_act_qspec(quantization_config)
op_node.meta["quantization_annotation"] = QuantizationAnnotation(
input_qspec_map=input_qspec_map,
output_qspec=output_act_qspec,
_annotated=True,
)
return annotated_partitions
@register_annotator("sima_sigmoid")
def _sima_annotate_sigmoid(
gm: torch.fx.GraphModule,
quantization_config: Optional[QuantizationConfig],
filter_fn: Optional[Callable[[Node], bool]] = None,
) -> Optional[List[List[Node]]]:
sig_partitions = get_source_partitions(gm.graph, [torch.sigmoid, F.sigmoid, torch.nn.Sigmoid], filter_fn)
sig_partitions = list(itertools.chain.from_iterable(sig_partitions.values()))
def _sig_target_check(sig_node: Node) -> bool:
if sig_node.target != torch.ops.aten.sigmoid.default:
# TODO: change this to AnnotationException
raise Exception(
f"Expected sigmoid node: torch.ops.aten.sigmoid.default, but found {sig_node.target}"
" please check if you are calling the correct capture API"
)
return True
return _annotate_single_op(
quantization_config = quantization_config,
op_partitions = sig_partitions,
op_check = _sig_target_check,
)
@register_annotator("sima_silu")
def _sima_annotate_silu(
gm: torch.fx.GraphModule,
quantization_config: Optional[QuantizationConfig],
filter_fn: Optional[Callable[[Node], bool]] = None,
) -> Optional[List[List[Node]]]:
silu_partitions = get_source_partitions(gm.graph, [F.silu, torch.nn.SiLU], filter_fn)
silu_partitions = list(itertools.chain.from_iterable(silu_partitions.values()))
def _silu_target_check(silu_node: Node) -> bool:
if silu_node.target not in [torch.ops.aten.silu_.default, torch.ops.aten.silu.default]:
# TODO: change this to AnnotationException
raise Exception(
f"Expected SiLU node: torch.ops.aten.silu_.default, but found {silu_node.target}"
" please check if you are calling the correct capture API"
)
return True
return _annotate_single_op(
quantization_config = quantization_config,
op_partitions = silu_partitions,
op_check = _silu_target_check,
)
@register_annotator("sima_add_hardtanh")
def _sima_annotate_add_hardtanh(
gm: torch.fx.GraphModule,
quantization_config: Optional[QuantizationConfig],
filter_fn: Optional[Callable[[Node], bool]] = None,
) -> Optional[List[List[Node]]]:
fused_partitions = find_sequential_partitions(
gm, [torch.add, torch.nn.Hardtanh], filter_fn=filter_fn
)
annotated_partitions = []
for fused_partition in fused_partitions:
add_partition, hardtanh_partition = fused_partition
annotated_partitions.append(add_partition.nodes + hardtanh_partition.nodes)
if len(hardtanh_partition.output_nodes) > 1:
raise ValueError("hardtanh partition has more than one output node")
hardtanh_node = hardtanh_partition.output_nodes[0]
if len(add_partition.output_nodes) > 1:
raise ValueError("add partition has more than one output node")
add_node = add_partition.output_nodes[0]
if _is_annotated([hardtanh_node, add_node]):
continue
input_act_qspec = get_input_act_qspec(quantization_config)
output_act_qspec = get_output_act_qspec(quantization_config)
input_qspec_map = {}
input_act0 = add_node.args[0]
if isinstance(input_act0, Node):
if _is_input_large_scalar(input_act0, gm):
continue
if _is_input_non_float_tensor(input_act0):
continue
input_qspec_map[input_act0] = input_act_qspec
input_act1 = add_node.args[1]
if isinstance(input_act1, Node):
if _is_input_large_scalar(input_act1, gm):
continue
if _is_input_non_float_tensor(input_act1):
continue
input_qspec_map[input_act1] = input_act_qspec
add_node.meta["quantization_annotation"] = QuantizationAnnotation(
input_qspec_map=input_qspec_map,
_annotated=True,
)
hardtanh_node.meta["quantization_annotation"] = QuantizationAnnotation(
output_qspec=output_act_qspec,
_annotated=True,
)
return annotated_partitions
@register_annotator("sima_conv_hardtanh")
def _sima_annotate_conv_hardtanh(
gm: torch.fx.GraphModule,
quantization_config: Optional[QuantizationConfig],
filter_fn: Optional[Callable[[Node], bool]] = None,
) -> Optional[List[List[Node]]]:
annotated_partitions = []
for n in gm.graph.nodes:
if n.op != "call_function" or n.target not in [
torch.ops.aten.hardtanh.default,
torch.ops.aten.hardtanh_.default,
]:
continue
hardtanh_node = n
maybe_conv_node = n.args[0]
if (
not isinstance(maybe_conv_node, Node)
or maybe_conv_node.op != "call_function"
or maybe_conv_node.target
not in [
torch.ops.aten.conv1d.default,
torch.ops.aten.conv2d.default,
]
):
continue
conv_node = maybe_conv_node
input_qspec_map = {}
input_act = conv_node.args[0]
assert isinstance(input_act, Node)
input_qspec_map[input_act] = get_input_act_qspec(quantization_config)
weight = conv_node.args[1]
assert isinstance(weight, Node)
input_qspec_map[weight] = get_weight_qspec(quantization_config)
# adding weight node to the partition as well
partition = [hardtanh_node, conv_node, conv_node.args[1]]
bias = conv_node.args[2] if len(conv_node.args) > 2 else None
if isinstance(bias, Node):
input_qspec_map[bias] = get_bias_qspec(quantization_config)
partition.append(bias)
if _is_annotated(partition):
continue
if filter_fn and any(not filter_fn(n) for n in partition):
continue
conv_node.meta["quantization_annotation"] = QuantizationAnnotation(
input_qspec_map=input_qspec_map, _annotated=True
)
hardtanh_node.meta["quantization_annotation"] = QuantizationAnnotation(
output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type]
_annotated=True,
)
_mark_nodes_as_annotated(partition)
annotated_partitions.append(partition)
return annotated_partitions
@register_annotator("sima_conv_bn_hardtanh")
def _sima_annotate_conv_bn_hardtanh(
gm: torch.fx.GraphModule,
quantization_config: Optional[QuantizationConfig],
filter_fn: Optional[Callable[[Node], bool]] = None,
) -> Optional[List[List[Node]]]:
"""
    Find conv + batchnorm + hardtanh partitions.
Note: This is only used for QAT. In PTQ, batchnorm should already be fused into the conv.
"""
def get_pattern(conv_fn: Callable, hardtanh_is_inplace: bool):
def _conv_bn(x, conv_weight, conv_bias, bn_weight, bn_bias, bn_rm, bn_rv):
conv = conv_fn(x, conv_weight, conv_bias)
bn = F.batch_norm(conv, bn_rm, bn_rv, bn_weight, bn_bias, training=True)
output = F.hardtanh_(bn) if hardtanh_is_inplace else F.hardtanh(bn)
return output, {
"input": x,
"conv": conv,
"weight": conv_weight,
"bias": conv_bias,
"output": output,
}
return _WrapperModule(_conv_bn)
    # Needed for matching; otherwise the matches get filtered out due to unused
    # nodes returned by batch norm.
gm.graph.eliminate_dead_code()
gm.recompile()
matches = []
combinations = [
(F.conv1d, _conv1d_bn_example_inputs),
(F.conv2d, _conv2d_bn_example_inputs),
]
# Add `is_cuda` and `hardtanh_is_inplace` dimensions
combinations = itertools.product(
combinations,
[True, False] if torch.cuda.is_available() else [False], # is_cuda
[True, False], # hardtanh_is_inplace
)
# Match against all conv dimensions and cuda variants
for (conv_fn, example_inputs), is_cuda, hardtanh_is_inplace in combinations:
pattern = get_pattern(conv_fn, hardtanh_is_inplace)
pattern = get_aten_graph_module(pattern, example_inputs, is_cuda)
pattern.graph.eliminate_dead_code()
pattern.recompile()
matcher = SubgraphMatcherWithNameNodeMap(pattern, ignore_literals=True)
matches.extend(matcher.match(gm.graph))
# Annotate nodes returned in the matches
annotated_partitions = []
for match in matches:
name_node_map = match.name_node_map
input_node = name_node_map["input"]
conv_node = name_node_map["conv"]
weight_node = name_node_map["weight"]
bias_node = name_node_map["bias"]
output_node = name_node_map["output"]
# TODO: annotate the uses of input, weight, and bias separately instead
# of assuming they come from a single conv node. This is not possible today
# because input may have multiple users, and we can't rely on the conv node
# always being the first user. This was the case in models with skip
# connections like resnet18
# Validate conv args
        if conv_node.args[0] is not input_node:
            raise ValueError(f"Conv arg did not contain input node {input_node}")
        if conv_node.args[1] is not weight_node:
            raise ValueError(f"Conv arg did not contain weight node {weight_node}")
        if len(conv_node.args) > 2 and conv_node.args[2] is not bias_node:
            raise ValueError(f"Conv arg did not contain bias node {bias_node}")
# Skip if the partition is already annotated or is filtered out by the user
partition = [conv_node, weight_node]
if bias_node is not None:
partition.append(bias_node)
if _is_annotated(partition):
continue
if filter_fn and any(not filter_fn(n) for n in partition):
continue
# Annotate conv inputs and pattern output
input_qspec_map = {}
input_qspec_map[input_node] = get_input_act_qspec(quantization_config)
input_qspec_map[weight_node] = get_weight_qspec(quantization_config)
if bias_node is not None:
input_qspec_map[bias_node] = get_bias_qspec(quantization_config)
conv_node.meta["quantization_annotation"] = QuantizationAnnotation(
input_qspec_map=input_qspec_map,
_annotated=True,
)
output_node.meta["quantization_annotation"] = QuantizationAnnotation(
output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type]
_annotated=True,
)
_mark_nodes_as_annotated(partition)
annotated_partitions.append(partition)
return annotated_partitions
@register_annotator("sima_conv_add_or_mul_const")
def _sima_annotate_conv_add_or_mul_const(
gm: torch.fx.GraphModule,
quantization_config: Optional[QuantizationConfig],
filter_fn: Optional[Callable[[Node], bool]] = None,
) -> Optional[List[List[Node]]]:
annotated_partitions = []
for n in gm.graph.nodes:
if n.op != "call_function" or n.target not in [
torch.ops.aten.add.Tensor,
torch.ops.aten.mul.Tensor,
]:
continue
        # Check whether either argument is a constant (get_attr); the other should be the conv.
        if isinstance(n.args[0], Node) and n.args[0].op == "get_attr":
            conv_node_id = 1
        elif isinstance(n.args[1], Node) and n.args[1].op == "get_attr":
            conv_node_id = 0
        else:
            continue
op_node = n
maybe_conv_node = n.args[conv_node_id]
if (
not isinstance(maybe_conv_node, Node)
or maybe_conv_node.op != "call_function"
or maybe_conv_node.target
not in [
torch.ops.aten.conv1d.default,
torch.ops.aten.conv2d.default,
]
):
continue
conv_node = maybe_conv_node
input_qspec_map = {}
input_act = conv_node.args[0]
assert isinstance(input_act, Node)
input_qspec_map[input_act] = get_input_act_qspec(quantization_config)
weight = conv_node.args[1]
assert isinstance(weight, Node)
input_qspec_map[weight] = get_weight_qspec(quantization_config)
# adding weight node to the partition as well
partition = [op_node, conv_node, conv_node.args[1]]
bias = conv_node.args[2] if len(conv_node.args) > 2 else None
if isinstance(bias, Node):
input_qspec_map[bias] = get_bias_qspec(quantization_config)
partition.append(bias)
if _is_annotated(partition):
continue
if filter_fn and any(not filter_fn(n) for n in partition):
continue
conv_node.meta["quantization_annotation"] = QuantizationAnnotation(
input_qspec_map=input_qspec_map, _annotated=True
)
op_node.meta["quantization_annotation"] = QuantizationAnnotation(
output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type]
_annotated=True,
)
_mark_nodes_as_annotated(partition)
annotated_partitions.append(partition)
return annotated_partitions
@register_annotator("sima_slice_select_unsqueeze")
def _sima_annotate_slice_select_unsqueeze(
gm: torch.fx.GraphModule,
quantization_config: Optional[QuantizationConfig],
filter_fn: Optional[Callable[[Node], bool]] = None,
) -> Optional[List[List[Node]]]:
annotated_partitions = []
input_act_qspec = get_input_act_qspec(quantization_config)
output_act_qspec = get_output_act_qspec(quantization_config)
for node in gm.graph.nodes:
if node.op != "call_function" or node.target not in [
torch.ops.aten.unsqueeze.default,
]:
continue
unsqueeze_node = node
maybe_select_node = node.args[0]
if (
not isinstance(maybe_select_node, Node)
or maybe_select_node.op != "call_function"
or maybe_select_node.target != torch.ops.aten.select.int
):
continue
select_node = maybe_select_node
maybe_slice_node = select_node.args[0]
if (
not isinstance(maybe_slice_node, Node)
or maybe_slice_node.op != "call_function"
or maybe_slice_node.target != torch.ops.aten.slice.Tensor
):
continue
slice_node = maybe_slice_node
input_qspec_map = {}
input_act = slice_node.args[0]
assert isinstance(input_act, Node)
input_qspec_map[input_act] = input_act_qspec
        # the partition covers the whole slice -> select -> unsqueeze chain
partition = [unsqueeze_node, select_node, slice_node]
if _is_annotated(partition):
continue
if filter_fn and any(not filter_fn(n) for n in partition):
continue
slice_node.meta["quantization_annotation"] = QuantizationAnnotation(
input_qspec_map=input_qspec_map,
_annotated=True,
)
unsqueeze_node.meta["quantization_annotation"] = QuantizationAnnotation(
output_qspec=output_act_qspec,
_annotated=True,
)
_mark_nodes_as_annotated(partition)
annotated_partitions.append(partition)
return annotated_partitions
@register_annotator("sima_dropout")
def _sima_annotate_dropout(
gm: torch.fx.GraphModule,
quantization_config: Optional[QuantizationConfig],
filter_fn: Optional[Callable[[Node], bool]] = None,
) -> Optional[List[List[Node]]]:
annotated_partitions = []
input_act_qspec = get_input_act_qspec(quantization_config)
output_act_qspec = get_output_act_qspec(quantization_config)
dropout_nodes = []
for node in gm.graph.nodes:
if node.op != "call_function" or node.target not in [
torch.ops.aten.dropout.default,
torch.ops.aten.dropout_.default,
]:
continue
dropout_node = node
dropout_nodes.append(dropout_node)
for node in gm.graph.nodes:
if node.op != "call_function" or node.args[0] not in dropout_nodes:
continue
node_after = node
dropout_node = node.args[0]
input_qspec_map = {}
input_act = dropout_node.args[0]
assert isinstance(input_act, Node)
input_qspec_map[input_act] = input_act_qspec
        # the partition covers the dropout node and its consumer
partition = [node_after, dropout_node]
if _is_annotated(partition):
continue
if filter_fn and any(not filter_fn(n) for n in partition):
continue
dropout_node.meta["quantization_annotation"] = QuantizationAnnotation(
input_qspec_map=input_qspec_map,
_annotated=True,
)
node_after.meta["quantization_annotation"] = QuantizationAnnotation(
output_qspec=output_act_qspec,
_annotated=True,
)
_mark_nodes_as_annotated(partition)
annotated_partitions.append(partition)
return annotated_partitions
@register_annotator("sima_batchnorm")
def _annotate_batchnorm(
gm: torch.fx.GraphModule,
quantization_config: Optional[QuantizationConfig],
filter_fn: Optional[Callable[[Node], bool]] = None,
) -> Optional[List[List[Node]]]:
annotated_partitions = []
for n in gm.graph.nodes:
if n.op != "call_function" or n.target not in [
operator.getitem
]:
continue
getitem_node = n
maybe_batchnorm_node = n.args[0]
if (
not isinstance(maybe_batchnorm_node, Node)
or maybe_batchnorm_node.op != "call_function"
or maybe_batchnorm_node.target
not in [
torch.ops.aten._native_batch_norm_legit.default
]
):
continue
batchnorm_node = maybe_batchnorm_node
input_qspec_map = {}
input_act = batchnorm_node.args[0]
assert isinstance(input_act, Node)
input_qspec_map[input_act] = get_input_act_qspec(quantization_config)
partition = [getitem_node, batchnorm_node]
if _is_annotated(partition):
continue
if filter_fn and any(not filter_fn(n) for n in partition):
continue
batchnorm_node.meta["quantization_annotation"] = QuantizationAnnotation(
input_qspec_map=input_qspec_map,
_annotated=True,
)
getitem_node.meta["quantization_annotation"] = QuantizationAnnotation(
output_qspec=get_output_act_qspec(quantization_config),
_annotated=True,
)
_mark_nodes_as_annotated(partition)
annotated_partitions.append(partition)
return annotated_partitions
@register_annotator("sima_cat")
def _sima_annotate_cat(
gm: torch.fx.GraphModule,
quantization_config: Optional[QuantizationConfig],
filter_fn: Optional[Callable[[Node], bool]] = None,
) -> Optional[List[List[Node]]]:
cat_partitions = get_source_partitions(gm.graph, [torch.cat], filter_fn)
cat_partitions = list(itertools.chain.from_iterable(cat_partitions.values()))
annotated_partitions = []
for cat_partition in cat_partitions:
cat_node = cat_partition.output_nodes[0]
if _is_annotated([cat_node]):
continue
if cat_node.target != torch.ops.aten.cat.default:
# TODO: change this to AnnotationException
raise Exception(
f"Expected cat node: torch.ops.aten.cat.default, but found {cat_node.target}"
" please check if you are calling the correct capture API"
)
annotated_partitions.append(cat_partition.nodes)
input_act_qspec = get_input_act_qspec(quantization_config)
inputs = cat_node.args[0]
input_qspec_map = {}
for input_act in inputs:
if _is_annotated([input_act]):
continue
input_qspec_map[input_act] = input_act_qspec
output_act_qspec = get_output_act_qspec(quantization_config)
cat_node.meta["quantization_annotation"] = QuantizationAnnotation(
input_qspec_map=input_qspec_map,
output_qspec=output_act_qspec,
_annotated=True,
)
return annotated_partitions
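# Registration note: each @register_annotator("name") call above adds the
# annotator to OP_TO_ANNOTATOR (the same registry the XNNPACK annotators use),
# and SimaQuantizer._annotate_all_static_patterns looks the names up from
# STATIC_OPS / STATIC_QAT_ONLY_OPS. Adding a new SiMa-specific pattern therefore
# means writing an annotator like the ones above and listing its name there.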