.. _Quantization Aware Training:

Quantization Aware Training
###########################

SiMa.ai supports Quantization Aware Training (QAT) to achieve high accuracy for machine learning models. This section describes the pre-requisites, the methods, and the outcomes of using the QAT option in the SiMa ModelSDK.

QAT is considered the highest-effort approach to mapping machine learning models onto special-purpose hardware. It is a machine learning technique used to tune an ML model directly to a given target device. It uses the Stochastic Gradient Descent (SGD) optimization method to modify the parameters of an ML model so that it conforms to a new set of constraints imposed by the target device. These constraints usually take the form of number encodings (e.g. FP32->INT8), but may optionally include constraints such as memory size and approximations of non-linear functions.

Pre-requisites
**************

* PyTorch, a major ML framework, which the ModelSDK uses to implement QAT.
* A high-quality source repository (original training data) that has been vetted to train with high accuracy.

QAT Process
***********

.. image:: media/qatprocessdiagram.png
   :alt: QAT Process Diagram
   :scale: 80%
   :align: center

QAT uses the same basic training loop as any ML model. In fact, QAT is best done in the original training environment of the ML model being optimized, because of the following special considerations:

#. Training almost always applies special augmentations to the training data. This is a common technique, and researchers usually rely on training-time augmentations when they publish their best FP32 numbers.
#. The loss function for many networks can be rather complex. The YOLO family of models is a good example of this.
   YOLOv8 uses a custom loss function that combines three different loss measurements, and this technique substantially improves network accuracy over the simple loss function used by early-generation YOLO networks.
#. To reach the same level of generality as the FP32 network, the QAT version of the network should be optimized over the same set of data samples as the fully trained version. This requires the original training set.

QAT Workflow
************

.. image:: media/qatuserworkflow.png
   :alt: QAT User Workflow
   :scale: 60%
   :align: center

The SiMa QAT workflow is similar to the one commonly used in the industry for training machine learning models. It begins with a model in PyTorch format at the model definition stage. PyTorch is used mainly because models exported to ONNX lack some of the important layers used to stabilize the training process (notably BatchNorm). These layers are important for training a model and can have a significant impact on training quality.

Although it is not required, the training process should begin with a pre-trained FP32 model in order to reach high accuracy quickly. If you do not have pre-trained results, you can start from a random initialization, which will eventually converge to a good result during QAT. This takes much longer than starting from a pre-trained model, but it is an option when a pre-trained model is not available.

Model Training Process
**********************

The QAT training process is much like training a simple FP32 model. There is one major difference, in the forward evaluation process: at the output of each operator in the graph, PyTorch inserts Quantize/Dequantize operators that modify the activation tensor. This simulates a lower-precision evaluation of the graph, with some inaccuracies. The backward pass that computes gradients and the weight update scheme are unmodified, consistent with nearly all other published QAT methods.
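
The Quantize/Dequantize step described above can be illustrated with a minimal pure-Python sketch. It is not PyTorch's or SiMa's implementation; it assumes an unsigned 8-bit, per-tensor affine encoding, and the helper names ``choose_uint8_params`` and ``quantize_dequantize`` are hypothetical, introduced only for this illustration.

```python
from typing import List, Tuple


def choose_uint8_params(values: List[float]) -> Tuple[float, int]:
    """Pick a scale and zero point that map the value range onto 0..255."""
    lo = min(min(values), 0.0)  # widen the range so 0.0 is exactly representable
    hi = max(max(values), 0.0)
    scale = (hi - lo) / 255.0 or 1.0  # fall back to 1.0 for an all-zero tensor
    zero_point = round(-lo / scale)
    return scale, zero_point


def quantize_dequantize(values: List[float], scale: float, zero_point: int) -> List[float]:
    """Round each value to a UINT8 code, then map the code back to a float."""
    out = []
    for v in values:
        q = round(v / scale) + zero_point   # quantize
        q = max(0, min(255, q))             # clamp to the UINT8 range
        out.append((q - zero_point) * scale)  # dequantize
    return out


# Illustrative activation values (made up for this sketch).
activations = [-1.0, -0.25, 0.0, 0.3, 2.0]
scale, zero_point = choose_uint8_params(activations)
simulated = quantize_dequantize(activations, scale, zero_point)

# The simulated values differ from the originals by at most half a step.
max_error = max(abs(a - s) for a, s in zip(activations, simulated))
```

The values in ``simulated`` are still floating point, but each one sits on the 8-bit grid, which is the "lower-precision evaluation with some inaccuracies" that the forward pass sees during QAT.
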
PyTorch implements this quantize/dequantize scheme as its fundamental mechanism of quantization-awareness. The SiMa QAT process uses weights in per-channel INT8 encoding and activations in per-tensor UINT8 encoding. Mixed-precision support will be implemented in a future release.

Scaffolding
===========

During training, extra operators are added to the graph: the QDQ (Quantize and DeQuantize) operations themselves, and observers that calibrate the scale and offset of activation tensors. These extra operators are commonly called “scaffolding”. SiMa uses that term here when describing a QAT-annotated ML model.

Finalization
============

PyTorch uses a separate API call at the conclusion of model training, which substantially transforms the graph away from its scaffolded form. This call removes all observers and instantiates the quantize/dequantize nodes. After this point, the model can be run in PyTorch in inference mode only, and it can be exported to ONNX.

Export
======

The finalized graph can be exported by using the provided `torch.export` function. Model export generates a legal ONNX model which conforms to the ONNX standard (with no extra extensions). This model contains QDQ annotations which the ModelSDK consumes during the compile process.

Training Loop
=============

SiMa requires the user to implement the core training loop. This is important because the training loop may contain arbitrarily complicated pre/post-processing and a loss function calculation. Since this cannot be handled in a generic way, SiMa provides APIs that the training loop code should call.

PyTorch Integration
===================

SiMa uses the PyTorch PT2EQ capabilities to provide QAT functionality to users. Users do not need to invoke PyTorch's API directly; SiMa provides this feature transparently.
Users should ensure that the PyTorch code to be retrained using QAT is well-behaved under PyTorch's existing `torch.compile` API. PyTorch supports QAT via the FX compiled code mechanism. Users should take steps to ensure that candidate models for QAT are well-behaved and arithmetically correct when `torch.compile` is applied. This is a mandatory pre-requisite for using SiMa's QAT capability.

User API
========

SiMa uses a small set of APIs to properly implement QAT on a user model.

.. list-table:: QAT User API Functions
   :widths: 75 25
   :header-rows: 1
   :align: left

   * - API Function
     - Description
   * - sima_prepare_qat_model

       .. code-block:: python

          sima_prepare_qat_model(input_graph: nn.Module, inputs: Tuple, device: torch.device) -> torch.nn.Module

     - Takes any PyTorch nn.Module and prepares it for QAT training.
   * - sima_finalize_qat_model

       .. code-block:: python

          sima_finalize_qat_model(qat_model: torch.nn.Module) -> torch.nn.Module

     - Takes a trained QAT model and finalizes it. It becomes inference-only after this point.
   * - sima_export_onnx

       .. code-block:: python

          sima_export_onnx(qat_model: nn.Module, inputs: Tuple[Tensor], output_file: str, input_names: Optional[List[str]] = None, output_names: Optional[List[str]] = None) -> None

     - Takes the finalized QAT model and exports an ONNX graph for the same.
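
The order in which a user-written training loop might call these three functions can be sketched as follows. This is an illustration under stated assumptions, not SiMa's reference code: the import location of the SiMa functions is not given in this section, so they are passed in as arguments, and ``run_qat``, ``make_optimizer``, ``batches``, and ``loss_fn`` are hypothetical names standing in for the user's own training environment.

```python
from typing import Any, Callable, Iterable, Tuple


def run_qat(
    model: Any,
    example_inputs: Tuple,
    device: Any,
    batches: Iterable[Tuple[Any, Any]],   # must be re-iterable across epochs
    loss_fn: Callable[[Any, Any], Any],
    make_optimizer: Callable[[Any], Any],
    prepare: Callable[..., Any],          # expected: sima_prepare_qat_model
    finalize: Callable[[Any], Any],       # expected: sima_finalize_qat_model
    export_onnx: Callable[..., None],     # expected: sima_export_onnx
    output_file: str,
    epochs: int = 1,
) -> Any:
    """Prepare, train, finalize, and export a model, in that order."""
    # 1. Add the QAT scaffolding to the original model.
    qat_model = prepare(model, example_inputs, device)

    # 2. Run the user's own training loop on the scaffolded model.
    optimizer = make_optimizer(qat_model)
    for _ in range(epochs):
        for inputs, targets in batches:
            optimizer.zero_grad()
            loss = loss_fn(qat_model(inputs), targets)
            loss.backward()
            optimizer.step()

    # 3. Finalize the model; it is inference-only from here on.
    finalized = finalize(qat_model)

    # 4. Export the finalized model as an ONNX graph with QDQ annotations.
    export_onnx(finalized, example_inputs, output_file)
    return finalized
```

In a real run, ``prepare``, ``finalize``, and ``export_onnx`` would be the three functions from the table, and ``make_optimizer`` could be something like ``lambda m: torch.optim.SGD(m.parameters(), lr=1e-4)``, where the learning rate is only a placeholder. The loss function, data augmentations, and data loader should come from the model's original training environment.
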