sima_qat.qat_api

Attributes

device_modifier_ops

Classes

SimaQatWrapper

This is a SiMa-defined wrapper which allows PyTorch GraphModule objects to behave like nn.Module objects at training time.

Functions

sima_prepare_qat_model(...) → torch.fx.graph_module.GraphModule

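This function is the first transformation needed to perform QAT on a PyTorch model. It takes an eager-mode reference to the ML model and produces an FX version of the graph with the annotations needed for QAT.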

sima_finalize_qat_model(...)

This function takes a QAT-scaffolded model that has completed its training regimen and converts it into an inference-only form.

sima_export_onnx(...) → torch.fx.graph_module.GraphModule

This function exports a finalized QAT model to ONNX format.

check_graph_nodes(...) → torch.fx.graph_module.GraphModule

Checks the prepared model for inconsistent device parameters, and also sets the dropout layers to inactive mode.

replace_dropout(...) → torch.fx.graph_module.GraphModule

replace_batchnorm(...) → torch.fx.graph_module.GraphModule

FX Graph rewriter to replace a flavor of batchnorm with one that can be exported.

Module Contents

sima_qat.qat_api.device_modifier_ops

sima_qat.qat_api.sima_prepare_qat_model(input_graph: torch.nn.Module, inputs: Tuple, device: torch.device) → torch.fx.graph_module.GraphModule

This function is the first transformation needed to perform QAT on a PyTorch model. It takes an eager-mode reference to the ML model and produces an FX version of the graph with the annotations needed for QAT. Internally, it scaffolds the graph with the observer and fakequant nodes needed during training.

Note

The PyTorch graph on which QAT is performed may be a full model or a subsection of a model. QAT optimization is limited to the graph passed as the input_graph argument. This region must always correspond to a single level of the module hierarchy, i.e. it must be expressible as a single nn.Module.

Parameters:
  • input_graph – an eager-mode nn.Module representing the model on which QAT is to be performed. This may be a full model or a subsection of an ML model.

  • inputs – a Tuple of tensor inputs whose shapes match the inputs expected by input_graph. The data can be randomly generated; it is used during preparation to build the compiled FX representation.

  • device – a PyTorch device identifier. The prepared model will be located on this device once the preparation step is complete.

Returns:

a compiled version of the given graph with QAT annotations, ready to begin training.

Return type:

GraphModule
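The observer and fakequant nodes that sima_prepare_qat_model inserts are not specified in detail here. As a rough illustration only, the following pure-Python sketch shows the standard min/max observer and fake-quantize pattern commonly used in QAT, assuming signed 8-bit affine quantization (qmin = -128, qmax = 127). MinMaxObserver and fake_quantize are hypothetical names, not part of sima_qat, and SiMa's actual observers may choose scales differently.

```python
import math
from typing import List, Tuple


class MinMaxObserver:
    """Hypothetical observer: tracks the running min/max of observed values."""

    def __init__(self) -> None:
        self.min_val = math.inf
        self.max_val = -math.inf

    def observe(self, values: List[float]) -> None:
        # Called on every training batch; widens the observed range as needed.
        self.min_val = min(self.min_val, min(values))
        self.max_val = max(self.max_val, max(values))

    def qparams(self, qmin: int = -128, qmax: int = 127) -> Tuple[float, int]:
        # Extend the range to include 0.0 so that zero is exactly representable.
        lo = min(self.min_val, 0.0)
        hi = max(self.max_val, 0.0)
        scale = (hi - lo) / (qmax - qmin) or 1.0  # fall back to 1.0 if the range is empty
        zero_point = round(qmin - lo / scale)
        return scale, max(qmin, min(qmax, zero_point))


def fake_quantize(values: List[float], scale: float, zero_point: int,
                  qmin: int = -128, qmax: int = 127) -> List[float]:
    """Quantize each value to an integer grid, clamp it, then dequantize it."""
    out = []
    for x in values:
        q = round(x / scale) + zero_point
        q = max(qmin, min(qmax, q))           # saturate to the integer range
        out.append((q - zero_point) * scale)  # back to float, now carrying rounding error
    return out


# Usage: observe a batch, derive quantization parameters, fake-quantize new values.
observer = MinMaxObserver()
observer.observe([-1.0, 0.5, 2.0])
scale, zero_point = observer.qparams()
simulated = fake_quantize([0.3, 0.5, -1.0, 5.0], scale, zero_point)
```

During training, the forward pass sees the fake-quantized (rounded and clamped) values, so the network learns to tolerate quantization error while its weights remain in floating point. In the usage above, the observed range [-1.0, 2.0] gives a scale of about 0.0118, and the out-of-range input 5.0 is clamped to roughly 2.0.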

sima_qat.qat_api.sima_finalize_qat_model(qat_model: torch.fx.graph_module.GraphModule) → torch.fx.graph_module.GraphModule

This function takes a QAT-scaffolded model that has completed its training regimen and converts it into an inference-only form in which quantization is expressed through fakequant operations. Once this conversion is complete, the model can no longer be trained; it is intended for export to ONNX.

Parameters:

qat_model – a trained QAT model to be converted into inference-only form.

Returns:

an inference-only version of the QAT model, which can be run in PyTorch evaluation mode or exported to ONNX.

Return type:

GraphModule

sima_qat.qat_api.sima_export_onnx(qat_model: torch.nn.Module, inputs: Tuple[torch.Tensor], output_file: str, input_names: List[str] | None = None, output_names: List[str] | None = None, device: torch.device = 'cuda') → torch.fx.graph_module.GraphModule

This function exports a finalized QAT model to ONNX format.

Parameters:
  • qat_model – the finalized ML model to export to ONNX (for example, the model returned by sima_finalize_qat_model).

  • inputs – a Tuple of tensor inputs used by the PyTorch ONNX exporter to infer the shapes of all internal tensors.

  • output_file – the path name of the .onnx file to generate.

  • input_names – a list of tensor names used to label the ONNX model inputs.

  • output_names – a list of tensor names used to label the ONNX model outputs.

class sima_qat.qat_api.SimaQatWrapper(source: torch.fx.graph_module.GraphModule, label: str)

This is a SiMa-defined wrapper which allows PyTorch GraphModule objects to behave like nn.Module objects at training time. It ensures that commonly called nn.Module methods, such as train(), eval() and load_state_dict(), work correctly while QAT is in progress.

Note

This wrapper can only be created from an existing GraphModule. The source GraphModules are created by PyTorch at each control point during QAT runtime.
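One common way to make a wrapper behave like the object it wraps is attribute delegation. The sketch below illustrates that general pattern and assumes nothing about SimaQatWrapper's internals: DelegatingWrapper and ToyGraph are hypothetical, with ToyGraph standing in for a GraphModule. Calls and unknown attributes are forwarded to source, while label is stored on the wrapper itself.

```python
from typing import Any


class DelegatingWrapper:
    """Hypothetical wrapper that forwards calls and unknown attributes to `source`."""

    def __init__(self, source: Any, label: str) -> None:
        self.source = source
        self.label = label

    def __getattr__(self, name: str) -> Any:
        # Python calls __getattr__ only when normal lookup fails. Guard the
        # wrapper's own fields so a half-initialized object cannot recurse.
        if name in ("source", "label"):
            raise AttributeError(name)
        return getattr(self.source, name)

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        # Forward calls so the wrapper can be invoked like the wrapped object.
        return self.source(*args, **kwargs)


class ToyGraph:
    """Stand-in for a GraphModule: callable, with some attributes."""

    def __init__(self) -> None:
        self.num_nodes = 3

    def __call__(self, x: int) -> int:
        return x * 2


wrapped = DelegatingWrapper(ToyGraph(), label="stage0")
print(wrapped(4), wrapped.num_nodes)  # 8 3
```

Delegation keeps the wrapper thin: only the behavior that must differ (such as the methods documented below) needs to be defined on the wrapper.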

train(use_train: bool = True) → SimaQatWrapper

This function emulates the behavior of train() on nn.Module.

Parameters:

use_train – set to True if training mode is desired, False if evaluation mode is desired.

Returns:

the wrapper itself (self). This return value is consistent with the behavior of nn.Module.train(), which also returns self.

Return type:

SimaQatWrapper

eval(use_eval: bool = True) → SimaQatWrapper

This function sets the QAT module to evaluation mode. It is equivalent to nn.Module.eval().

Parameters:

use_eval – set to True if evaluation mode is desired, False if training mode.

Returns:

the wrapper itself (self). This return value is consistent with the behavior of nn.Module.eval(), which also returns self.

Return type:

SimaQatWrapper
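Assuming SimaQatWrapper follows the nn.Module convention, in which train(mode) sets a training flag and returns the module itself and eval() is equivalent to train(False), the semantics of train(use_train) and eval(use_eval) can be sketched as follows. ModeSwitchSketch is a hypothetical stand-in, not the wrapper's implementation.

```python
class ModeSwitchSketch:
    """Hypothetical sketch of nn.Module-style train()/eval() semantics."""

    def __init__(self) -> None:
        self.training = True  # nn.Module instances also start in training mode

    def train(self, use_train: bool = True) -> "ModeSwitchSketch":
        self.training = use_train
        return self  # returning self allows chaining, as nn.Module.train() does

    def eval(self, use_eval: bool = True) -> "ModeSwitchSketch":
        # eval(True) selects evaluation mode, i.e. the opposite of training mode.
        return self.train(not use_eval)


model = ModeSwitchSketch().eval()
print(model.training)  # False
```

Because both methods return self, ordinary PyTorch idioms such as model = model.eval() keep working when the model is wrapped.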

load_state_dict(state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False)

This specialized load function adds a check for QAT models: it ensures that the loaded state corresponds to the same phase of QAT training as the model skeleton in memory.

Parameters:
  • state_dict – a mapping of string keys to learned state (dense tensor data).

  • strict – if True, the keys in state_dict must exactly match the keys of this module's state. If False, missing or unexpected keys are allowed.

  • assign – when False, the properties of the tensors in the current module are preserved; when True, the properties of the tensors in the state dict are preserved. The only exception is the requires_grad field of torch.nn.Parameter objects, for which the value from the module is preserved.
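The phase check can be pictured with a minimal sketch over plain dictionaries, assuming the QAT phase is recorded under a state key. The key name _qat_phase, the phase labels "prepared" and "finalized", and load_with_phase_check are all hypothetical. The strict handling mirrors nn.Module.load_state_dict, which returns a named tuple of missing_keys and unexpected_keys.

```python
from typing import Any, Dict, List, NamedTuple

PHASE_KEY = "_qat_phase"  # hypothetical key recording the QAT phase of a state


class IncompatibleKeys(NamedTuple):
    missing_keys: List[str]
    unexpected_keys: List[str]


def load_with_phase_check(model_state: Dict[str, Any],
                          state_dict: Dict[str, Any],
                          strict: bool = True) -> IncompatibleKeys:
    """Copy state_dict into model_state after a QAT phase check (sketch)."""
    saved, current = state_dict.get(PHASE_KEY), model_state.get(PHASE_KEY)
    if saved != current:
        # Refuse to mix, e.g., a prepared checkpoint with a finalized model.
        raise RuntimeError(f"QAT phase mismatch: checkpoint={saved!r}, model={current!r}")

    missing = [k for k in model_state if k not in state_dict]
    unexpected = [k for k in state_dict if k not in model_state]
    if strict and (missing or unexpected):
        raise RuntimeError(f"missing keys: {missing}, unexpected keys: {unexpected}")

    # Load every key both sides share; in non-strict mode mismatched keys are skipped.
    for key, value in state_dict.items():
        if key in model_state:
            model_state[key] = value
    return IncompatibleKeys(missing, unexpected)


model = {PHASE_KEY: "prepared", "w": 0.0, "b": 0.0}
result = load_with_phase_check(model, {PHASE_KEY: "prepared", "w": 1.5}, strict=False)
print(result)  # IncompatibleKeys(missing_keys=['b'], unexpected_keys=[])
```

Checking the phase before any key is copied means a rejected checkpoint leaves the in-memory model untouched.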

sima_qat.qat_api.check_graph_nodes(prepared_mod: torch.fx.graph_module.GraphModule, device: torch.device) → torch.fx.graph_module.GraphModule

Checks the prepared model for inconsistent device parameters, and also sets the dropout layers to inactive mode.
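The two responsibilities of check_graph_nodes can be illustrated on a toy node list. This sketch uses a simplified, hypothetical node record (ToyNode) rather than real torch.fx nodes, and it only reports device mismatches; whether check_graph_nodes reports, moves, or rejects mismatched parameters is not documented here.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class ToyNode:
    name: str
    kind: str             # e.g. "conv" or "dropout"
    device: str           # e.g. "cpu" or "cuda"
    training: bool = True


def check_nodes_sketch(nodes: List[ToyNode], device: str) -> List[str]:
    """Report nodes not on `device` and put dropout nodes in inactive mode."""
    mismatched = []
    for node in nodes:
        if node.device != device:
            mismatched.append(node.name)
        if node.kind == "dropout":
            # An inactive dropout passes its input through unchanged.
            node.training = False
    return mismatched


graph = [ToyNode("conv1", "conv", "cuda"), ToyNode("drop1", "dropout", "cpu")]
print(check_nodes_sketch(graph, "cuda"))  # ['drop1']
```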

sima_qat.qat_api.replace_dropout(m: torch.fx.graph_module.GraphModule) → torch.fx.graph_module.GraphModule

sima_qat.qat_api.replace_batchnorm(m: torch.fx.graph_module.GraphModule) → torch.fx.graph_module.GraphModule

FX Graph rewriter to replace a flavor of batchnorm with one that can be exported.
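FX rewriters of this kind typically walk the graph's nodes, match the operation to replace, and substitute an equivalent one; with torch.fx this usually means iterating over gm.graph.nodes and calling gm.recompile() after editing the graph. The toy sketch below shows only the find-and-substitute step on a plain list of operations. ToyOp, rewrite_ops, and the op names bn_variant and bn_exportable are hypothetical, since the batchnorm flavor that replace_batchnorm targets is not named here.

```python
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class ToyOp:
    name: str
    op_type: str
    args: List[str] = field(default_factory=list)  # names of input ops


def rewrite_ops(graph: List[ToyOp], replacements: Dict[str, str]) -> int:
    """Swap matching op types in place and return how many ops were rewritten."""
    count = 0
    for op in graph:
        new_type = replacements.get(op.op_type)
        if new_type is not None:
            op.op_type = new_type  # inputs (args) are left untouched
            count += 1
    return count


graph = [
    ToyOp("conv1", "conv", ["x"]),
    ToyOp("bn1", "bn_variant", ["conv1"]),
    ToyOp("relu1", "relu", ["bn1"]),
]
print(rewrite_ops(graph, {"bn_variant": "bn_exportable"}))  # 1
```

Keeping each op's inputs unchanged is what makes the swap safe: downstream ops still refer to bn1 by name, so the rest of the graph needs no edits.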