sima_qat.qat_api
Attributes
Classes
- SimaQatWrapper: a Sima-defined wrapper that allows PyTorch GraphModule objects to behave like nn.Module objects at training time.
Functions
- sima_prepare_qat_model: the first transformation needed to perform QAT on a PyTorch model.
- sima_finalize_qat_model: converts a QAT-scaffolded model that has completed its training regimen into an inference-only form.
- sima_export_onnx: exports a finalized QAT model to ONNX format.
- check_graph_nodes: checks the prepared model for inconsistent device parameters and sets the dropout layers to inactive mode.
- FX graph rewriter that replaces a flavor of batchnorm with one that can be exported
Module Contents
- sima_qat.qat_api.sima_prepare_qat_model(input_graph: torch.nn.Module, inputs: Tuple, device: torch.device) -> torch.fx.graph_module.GraphModule
This function is the first transformation needed to perform QAT on a PyTorch model. It takes an eager-mode reference to the ML model and produces an FX version of the graph with the special annotations needed for QAT. Internally, it scaffolds the graph with the observers and fakequant nodes needed during training.
Note
The PyTorch graph on which QAT is performed may be a full model or a subsection of a model. QAT optimization is limited to the graph given by the input_graph argument. This region must always be confined to a single level of hierarchy, as described by a single nn.Module.
- Parameters:
input_graph: an eager-mode nn.Module representing the model on which QAT is to be performed. This may be a full model or a subsection of an ML model.
inputs: a Tuple of tensor inputs whose shapes match the inputs expected by input_graph. This data can be randomly generated; it is used during preparation to build the compiled FX representation.
device: a PyTorch device identifier. The prepared model is placed on this device once preparation is complete.
- Returns:
a compiled version of the given graph with QAT annotations, ready to begin training.
- Return type:
GraphModule
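The observers and fakequant nodes mentioned above are standard QAT machinery: an observer records the range of values flowing through a tensor, and a fakequant node rounds and clamps values onto the integer grid implied by that range while keeping them in floating point. The framework-free sketch below illustrates that idea for a single list of values. It assumes asymmetric uint8 quantization with a min/max observer; the class and function names are hypothetical, and the observers and quantization scheme that sima_qat actually inserts are not specified in this reference.

```python
# Illustrative sketch only: a min/max observer plus fake quantization,
# assuming asymmetric uint8 quantization. Not sima_qat code.
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class MinMaxObserver:
    """Tracks the running min/max of every value it has seen."""
    min_val: float = float("inf")
    max_val: float = float("-inf")

    def observe(self, values: List[float]) -> None:
        self.min_val = min(self.min_val, min(values))
        self.max_val = max(self.max_val, max(values))

    def qparams(self, qmin: int = 0, qmax: int = 255) -> Tuple[float, int]:
        # Widen the range to include 0.0 so zero is exactly representable.
        lo = min(self.min_val, 0.0)
        hi = max(self.max_val, 0.0)
        scale = (hi - lo) / (qmax - qmin) or 1.0  # guard against a zero range
        zero_point = int(round(qmin - lo / scale))
        return scale, max(qmin, min(qmax, zero_point))


def fake_quantize(values: List[float], scale: float, zero_point: int,
                  qmin: int = 0, qmax: int = 255) -> List[float]:
    """Quantize, then immediately dequantize: outputs stay float but carry
    the rounding and clamping error an integer model would see."""
    out = []
    for v in values:
        q = max(qmin, min(qmax, round(v / scale) + zero_point))
        out.append((q - zero_point) * scale)
    return out


acts = [-1.0, 0.0, 0.5, 2.0]
obs = MinMaxObserver()
obs.observe(acts)
scale, zp = obs.qparams()
fq = fake_quantize(acts, scale, zp)
print(scale, zp, fq)
```

During training, gradients flow through the fake-quantized values, so the model learns weights that tolerate this rounding error.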
- sima_qat.qat_api.sima_finalize_qat_model(qat_model: torch.fx.graph_module.GraphModule) -> torch.fx.graph_module.GraphModule
This function takes a QAT-scaffolded model that has completed its training regimen and converts it to an inference-only form (implemented via fakequant). Once this process is complete, the model can no longer be trained and is intended for export to ONNX.
- Parameters:
qat_model: a trained QAT model to be converted into inference-only form.
- Returns:
- an inference-only version of the QAT model, which can be run in PyTorch evaluation (eval()) mode or exported to ONNX.
- Return type:
GraphModule
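Taken together, the three functions imply a fixed order: prepare, train, finalize, then export, with training no longer possible after finalization. The sketch below models only that ordering as a small state machine. The QatPhase and QatLifecycle names are hypothetical illustrations, not part of sima_qat, and the checks are assumptions drawn from the descriptions above rather than the library's actual validation.

```python
# Hypothetical state machine for the QAT lifecycle described in this API:
# prepare -> train -> finalize -> export. Not sima_qat code.
from enum import Enum, auto


class QatPhase(Enum):
    PREPARED = auto()   # state after sima_prepare_qat_model
    FINALIZED = auto()  # state after sima_finalize_qat_model


class QatLifecycle:
    def __init__(self) -> None:
        self.phase = QatPhase.PREPARED
        self.steps_trained = 0

    def train_step(self) -> None:
        # Only the scaffolded (prepared) model can be trained.
        if self.phase is not QatPhase.PREPARED:
            raise RuntimeError(f"cannot train a model in phase {self.phase.name}")
        self.steps_trained += 1

    def finalize(self) -> None:
        if self.phase is not QatPhase.PREPARED:
            raise RuntimeError("only a prepared model can be finalized")
        self.phase = QatPhase.FINALIZED

    def export(self) -> None:
        # Export expects a finalized, inference-only model.
        if self.phase is not QatPhase.FINALIZED:
            raise RuntimeError("finalize the model before exporting it")


lc = QatLifecycle()
lc.train_step()
lc.finalize()
lc.export()
try:
    lc.train_step()
except RuntimeError as err:
    print(err)  # cannot train a model in phase FINALIZED
```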
- sima_qat.qat_api.sima_export_onnx(qat_model: torch.nn.Module, inputs: Tuple[torch.Tensor], output_file: str, input_names: List[str] | None = None, output_names: List[str] | None = None, device: torch.device = 'cuda') -> torch.fx.graph_module.GraphModule
This function exports a finalized QAT model to ONNX format.
- Parameters:
qat_model: the finalized ML model to export to ONNX.
inputs: a Tuple of tensor inputs used to infer the proper shapes of all internal tensors. This is used by the PyTorch ONNX exporter.
output_file: the path of the .onnx file to generate.
input_names: a list of tensor names used to label the ONNX model inputs.
output_names: a list of tensor names used to label the ONNX model outputs.
- class sima_qat.qat_api.SimaQatWrapper(source: torch.fx.graph_module.GraphModule, label: str)
This is a Sima-defined wrapper that allows PyTorch GraphModule objects to behave like nn.Module objects at training time, so that commonly called PyTorch functions work correctly when QAT is invoked.
Note
This wrapper can only be created from an existing GraphModule. The source GraphModules are created by PyTorch at each control point during QAT runtime.
- train(use_train: bool = True) -> SimaQatWrapper
This function emulates the behavior of train() on nn.Module.
- Parameters:
use_train: set to True if training mode is desired, False if evaluation mode is desired.
- Returns:
- a copy of the self variable. This return value provides consistency with the behavior of nn.Module.train().
- Return type:
SimaQatWrapper
- eval(use_eval: bool = True) -> SimaQatWrapper
This function sets the QAT module to evaluation mode. It is equivalent to nn.Module.eval().
- Parameters:
use_eval: set to True if evaluation mode is desired, False if training mode is desired.
- Returns:
- a copy of the self variable. This return value provides consistency with the behavior of nn.Module.eval().
- Return type:
SimaQatWrapper
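The train() and eval() methods are described as emulating nn.Module. In PyTorch, nn.Module.train(mode) sets the training flag on the module and all of its children and returns the module itself, and eval() is shorthand for train(False). The framework-free sketch below mirrors that contract, including the use_train/use_eval arguments documented here. ModeWrapper and its children list are hypothetical stand-ins for the wrapped graph; this sketch returns self as nn.Module does, which may differ from how SimaQatWrapper builds its return value.

```python
# Hypothetical sketch of the nn.Module train()/eval() contract that
# SimaQatWrapper emulates. Not sima_qat code.
from typing import List, Optional


class ModeWrapper:
    def __init__(self, label: str,
                 children: Optional[List["ModeWrapper"]] = None) -> None:
        self.label = label
        self.training = True
        self.children = children or []

    def train(self, use_train: bool = True) -> "ModeWrapper":
        # Set the flag recursively and return self so calls can be chained,
        # e.g. model = model.train().
        self.training = use_train
        for child in self.children:
            child.train(use_train)
        return self

    def eval(self, use_eval: bool = True) -> "ModeWrapper":
        # eval(True) is train(False); eval(False) switches back to training.
        return self.train(not use_eval)


root = ModeWrapper("root", [ModeWrapper("child")])
root.eval()
print(root.training, root.children[0].training)  # False False
root.train()
print(root.training, root.children[0].training)  # True True
```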
- load_state_dict(state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False)
This specialized load function adds an additional check for QAT models. This check ensures that a loaded state corresponds to the same phase of QAT training as the model skeleton in memory.
- Parameters:
state_dict: a mapping of string keys to learned state (dense tensor data).
strict: if True, all keys in the state_dict must exactly match the contents of this Module. If False, keys are allowed to mismatch state elements within this Module.
assign: when False, the properties of the tensors in the current module are preserved, while when True, the properties of the tensors in the state dict are preserved. The only exception is the requires_grad field of torch.nn.Parameter objects, for which the value from the module is preserved.
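The phase check described above can be pictured as an extra guard in front of ordinary strict key matching. The sketch below is a hypothetical illustration using plain dictionaries: the "_qat_phase" marker key and the check_state_dict helper are assumptions, not part of sima_qat, whose actual check is not documented here. Like nn.Module.load_state_dict, it reports missing and unexpected keys and raises a RuntimeError on mismatch when strict is True.

```python
# Hypothetical sketch: a QAT-phase guard layered on strict key matching.
# PHASE_KEY and check_state_dict are illustrative names, not sima_qat API.
from typing import Any, Dict, List, Tuple

PHASE_KEY = "_qat_phase"  # assumed marker stored alongside the weights


def check_state_dict(model_state: Dict[str, Any],
                     loaded_state: Dict[str, Any],
                     strict: bool = True) -> Tuple[List[str], List[str]]:
    """Return (missing_keys, unexpected_keys); raise on phase or strict mismatch."""
    if loaded_state.get(PHASE_KEY) != model_state.get(PHASE_KEY):
        raise RuntimeError(
            f"state dict is from phase {loaded_state.get(PHASE_KEY)!r}, "
            f"but the model is in phase {model_state.get(PHASE_KEY)!r}")
    model_keys = set(model_state) - {PHASE_KEY}
    loaded_keys = set(loaded_state) - {PHASE_KEY}
    missing = sorted(model_keys - loaded_keys)
    unexpected = sorted(loaded_keys - model_keys)
    if strict and (missing or unexpected):
        raise RuntimeError(f"missing={missing}, unexpected={unexpected}")
    return missing, unexpected


model = {PHASE_KEY: "prepared", "conv.weight": [0.1], "conv.scale": [0.02]}
ckpt = {PHASE_KEY: "prepared", "conv.weight": [0.3]}
print(check_state_dict(model, ckpt, strict=False))  # (['conv.scale'], [])
```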
- sima_qat.qat_api.check_graph_nodes(prepared_mod: torch.fx.graph_module.GraphModule, device: torch.device) -> torch.fx.graph_module.GraphModule
Checks the prepared model for inconsistent device parameters and also sets the dropout layers to inactive mode.
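The two checks described for check_graph_nodes can be sketched without PyTorch by walking a list of node records. In the sketch below, Node and check_nodes are hypothetical stand-ins for FX graph nodes and their modules; it only reports parameters that sit on a different device, whereas the real function returns a GraphModule and may instead move or reject them. Dropout is switched to inactive mode because inactive dropout acts as an identity operation.

```python
# Hypothetical, framework-free sketch of the checks check_graph_nodes is
# described as performing. Node and check_nodes are not sima_qat API.
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class Node:
    name: str
    op_type: str                  # e.g. "conv", "linear", "dropout"
    device: Optional[str] = None  # device of the node's parameters, if any
    training: bool = True


def check_nodes(nodes: List[Node], device: str) -> List[str]:
    """Return names of nodes whose parameters are not on `device`, and put
    every dropout node into inactive (eval) mode."""
    mismatched = []
    for node in nodes:
        if node.device is not None and node.device != device:
            mismatched.append(node.name)
        if node.op_type == "dropout":
            node.training = False  # inactive dropout passes inputs through
    return mismatched


graph = [Node("conv1", "conv", "cuda"), Node("drop", "dropout"),
         Node("fc", "linear", "cpu")]
print(check_nodes(graph, "cuda"))  # ['fc']
print(graph[1].training)  # False
```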