sima_qat.qat_api
================

.. py:module:: sima_qat.qat_api


Attributes
----------

.. autoapisummary::

   sima_qat.qat_api.device_modifier_ops


Classes
-------

.. autoapisummary::

   sima_qat.qat_api.SimaQatWrapper


Functions
---------

.. autoapisummary::

   sima_qat.qat_api.sima_prepare_qat_model
   sima_qat.qat_api.sima_finalize_qat_model
   sima_qat.qat_api.sima_export_onnx
   sima_qat.qat_api.check_graph_nodes
   sima_qat.qat_api.replace_dropout
   sima_qat.qat_api.replace_batchnorm


Module Contents
---------------

.. py:data:: device_modifier_ops

.. py:function:: sima_prepare_qat_model(input_graph: torch.nn.Module, inputs: Tuple, device: torch.device) -> torch.fx.graph_module.GraphModule

   This function is the first transformation needed to perform QAT on a PyTorch model. It takes an
   eager-mode reference to the ML model and produces an FX version of the graph with the special
   annotations needed for QAT. Internally, it scaffolds the graph with the observer and fake-quant
   nodes needed during the training process.

   .. note::
      The PyTorch graph on which QAT is performed may be a full model or a subsection of a model.
      QAT optimization is limited to the graph given by the `input_graph` argument. This region
      must always correspond to a single level of hierarchy, i.e. one `nn.Module`.

   :param input_graph: an eager-mode `nn.Module` representing the model on which QAT is to be
                       performed. This may be a full model or a subsection of an ML model.
   :param inputs: a `Tuple` of tensor inputs, sized to the correct shape for the input to the
                  given `input_graph`. This data can be randomly generated; it is used during the
                  preparation process to build the compiled FX representation.
   :param device: a PyTorch `device` identifier. This is the device on which the prepared model
                  will be located after the preparation step is complete.

   :returns: a compiled version of the given graph with QAT annotations, ready to begin training.
   :rtype: GraphModule


.. py:function:: sima_finalize_qat_model(qat_model: torch.fx.graph_module.GraphModule) -> torch.fx.graph_module.GraphModule

   This function takes a QAT-scaffolded model which has completed the training regimen and
   converts it to an inference-only (fake-quant) form. Once this process is complete, the model
   can no longer be trained and is intended for export via ONNX.

   :param qat_model: a trained QAT model to be converted into inference-only form.

   :returns: an inference-only version of the QAT model, which can be run in PyTorch evaluation
             (`eval()`) mode or exported via ONNX.
   :rtype: GraphModule


.. py:function:: sima_export_onnx(qat_model: torch.nn.Module, inputs: Tuple[torch.Tensor], output_file: str, input_names: Optional[List[str]] = None, output_names: Optional[List[str]] = None, device: torch.device = 'cuda') -> torch.fx.graph_module.GraphModule

   This function exports a finalized QAT model to ONNX format.

   :param qat_model: the finalized ML model to export to ONNX.
   :param inputs: a `Tuple` of tensor inputs used to infer the proper shapes of all internal
                  tensors. This is used by the PyTorch ONNX exporter.
   :param output_file: the path name of the .onnx file to generate.
   :param input_names: a list of tensor names used to label the ONNX model inputs.
   :param output_names: a list of tensor names used to label the ONNX model outputs.
   :param device: the PyTorch device on which the export is performed.
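
The three functions above compose into a single prepare → train → finalize → export pipeline. The
following is a minimal sketch of that workflow using the documented signatures; ``MyModel``, the
input shape, and the training loop are placeholders, not part of this API:

.. code-block:: python

   import torch
   from sima_qat.qat_api import (
       sima_prepare_qat_model,
       sima_finalize_qat_model,
       sima_export_onnx,
   )

   device = torch.device("cuda")
   model = MyModel().to(device)  # placeholder eager-mode nn.Module
   example_inputs = (torch.randn(1, 3, 224, 224, device=device),)  # placeholder shape

   # Step 1: scaffold the graph with the observers and fake-quant nodes needed for QAT.
   qat_model = sima_prepare_qat_model(model, example_inputs, device)

   # Step 2: run the usual training loop on qat_model (omitted here).

   # Step 3: convert the trained graph to its inference-only (fake-quant) form.
   final_model = sima_finalize_qat_model(qat_model)

   # Step 4: export the finalized graph to ONNX.
   sima_export_onnx(
       final_model,
       example_inputs,
       output_file="model.onnx",
       input_names=["input"],
       output_names=["output"],
       device=device,
   )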
.. py:class:: SimaQatWrapper(source: torch.fx.graph_module.GraphModule, label: str)

   This is a Sima-defined wrapper which allows PyTorch GraphModule objects to behave like
   `nn.Module`s at training time. It is used so that commonly called PyTorch functions work
   correctly when QAT is invoked.

   .. note::
      This wrapper can only be created from an existing GraphModule. The source GraphModules are
      created by PyTorch at each control point during QAT runtime.


   .. py:method:: train(use_train: bool = True) -> SimaQatWrapper

      This function emulates the behavior of `train()` on `nn.Module`.

      :param use_train: set to `True` if training mode is desired, `False` if evaluation mode.

      :returns: a copy of the `self` variable. This return value provides consistency with the
                behavior of `nn.Module.train()`.
      :rtype: SimaQatWrapper


   .. py:method:: eval(use_eval: bool = True) -> SimaQatWrapper

      This function sets the QAT module to evaluation mode. It is equivalent to
      `nn.Module.eval()`.

      :param use_eval: set to `True` if evaluation mode is desired, `False` if training mode.

      :returns: a copy of the `self` variable. This return value provides consistency with the
                behavior of `nn.Module.eval()`.
      :rtype: SimaQatWrapper


   .. py:method:: load_state_dict(state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False)

      This specialized load function adds an additional check for QAT models. The check ensures
      that a loaded state corresponds to the same phase of QAT training as the model skeleton in
      memory.

      :param state_dict: a mapping of string keys to learned state (dense tensor data).
      :param strict: if `True`, all keys in the `state_dict` must exactly match the contents of
                     this Module. If `False`, keys are allowed to mismatch state elements within
                     this Module.
      :param assign: when ``False``, the properties of the tensors in the current module are
                     preserved, while when ``True``, the properties of the Tensors in the state
                     dict are preserved. The only exception is the ``requires_grad`` field of
                     :class:`~torch.nn.Parameter`s, for which the value from the module is
                     preserved.


.. py:function:: check_graph_nodes(prepared_mod: torch.fx.graph_module.GraphModule, device: torch.device) -> torch.fx.graph_module.GraphModule

   Checks the prepared model for inconsistent device parameters and sets dropout layers to
   inactive mode.


.. py:function:: replace_dropout(m: torch.fx.graph_module.GraphModule) -> torch.fx.graph_module.GraphModule


.. py:function:: replace_batchnorm(m: torch.fx.graph_module.GraphModule) -> torch.fx.graph_module.GraphModule

   FX graph rewriter to replace a flavor of batchnorm with one that can be exported.
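
Because `SimaQatWrapper` mirrors the `nn.Module` interface, an ordinary training loop works
unchanged. The following is a minimal sketch, assuming `qat_model` is a `SimaQatWrapper` produced
during QAT runtime; ``loader``, ``loss_fn``, ``optimizer``, and the checkpoint path are
placeholders:

.. code-block:: python

   import torch

   # train() returns self, mirroring nn.Module.train().
   qat_model.train()
   for images, labels in loader:
       optimizer.zero_grad()
       loss = loss_fn(qat_model(images), labels)
       loss.backward()
       optimizer.step()

   # Switch to evaluation mode for validation, mirroring nn.Module.eval().
   qat_model.eval()

   # load_state_dict() additionally verifies that the checkpoint matches the
   # current phase of QAT training before restoring the learned state.
   qat_model.load_state_dict(torch.load("qat_checkpoint.pt"), strict=True)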