load_model

As the first step of the PTQ process, the afe.apis.loaded_net.load_model() API loads a model and prepares it for quantization. The parameters that load_model requires depend on the type of model being loaded.

For example, when loading an ONNX model, both input names and input shapes must be specified, but they serve different purposes. Input names are used to identify inputs within the ONNX model and can be directly obtained by examining the model. Input shapes, however, may not always be defined in the ONNX model but are essential for the ModelSDK to process the model correctly. If input shapes are unspecified, they must be explicitly provided based on the expected input dimensions.
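One way to examine a model for its input names and declared shapes is the open-source onnx Python package, which is separate from the ModelSDK. The sketch below is only an illustration under that assumption: describe_inputs, inputs_needing_shapes, and describe_onnx_file are hypothetical helper names, not part of the afe API. It reports each dimension as an integer when the model fixes it, as a string when it is symbolic, and as None when it is not declared.

```python
def describe_inputs(graph):
    """Map each real input name of an ONNX graph to its declared shape.

    `graph` is expected to look like onnx.ModelProto.graph. Each dimension is
    an int when fixed, a str when symbolic (for example a dynamic batch size),
    and None when the model does not declare it.
    """
    # Some exporters also list weights in graph.input; those are not real inputs.
    initializer_names = {init.name for init in graph.initializer}
    shapes = {}
    for inp in graph.input:
        if inp.name in initializer_names:
            continue
        dims = []
        for dim in inp.type.tensor_type.shape.dim:
            if dim.dim_param:          # symbolic dimension such as "N"
                dims.append(dim.dim_param)
            elif dim.dim_value > 0:    # protobuf reports an unset value as 0
                dims.append(dim.dim_value)
            else:
                dims.append(None)
        shapes[inp.name] = tuple(dims)
    return shapes


def inputs_needing_shapes(shapes):
    """Return the input names whose shape is missing or not fully fixed."""
    return [
        name
        for name, dims in shapes.items()
        if not dims or any(not isinstance(d, int) for d in dims)
    ]


def describe_onnx_file(model_path):
    """Load an ONNX file and describe its inputs (requires the onnx package)."""
    import onnx  # imported lazily so the helpers above have no hard dependency

    return describe_inputs(onnx.load(model_path).graph)
```

Any name returned by inputs_needing_shapes is an input whose shape would have to be supplied explicitly, based on the dimensions the model is expected to receive.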

Understanding these distinctions is crucial for correctly configuring load_model and diagnosing potential errors during the loading process.

For ONNX models, the input names, shapes, and data types are passed as two dictionaries keyed by input name: one maps each input name to its shape, and the other maps it to its data type.

The input layout is assumed to follow the NCHW format.
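Because the two dictionaries must agree on their keys, a quick consistency check before building the importer parameters can catch simple mistakes early. The following is a minimal sketch, not ModelSDK functionality: check_input_dicts and as_nchw are hypothetical helpers, and the string "float32" is a placeholder for the ScalarType value used in real code (the check never inspects the type values).

```python
def check_input_dicts(shape_dict, dtype_dict):
    """Return a list of problems; an empty list means the dictionaries agree."""
    problems = []
    for name in sorted(shape_dict.keys() - dtype_dict.keys()):
        problems.append(f"'{name}' has a shape but no data type")
    for name in sorted(dtype_dict.keys() - shape_dict.keys()):
        problems.append(f"'{name}' has a data type but no shape")
    for name, shape in shape_dict.items():
        # Every dimension should be a concrete positive integer.
        if not all(
            isinstance(d, int) and not isinstance(d, bool) and d > 0 for d in shape
        ):
            problems.append(f"'{name}' has a non-concrete shape {tuple(shape)}")
    return problems


def as_nchw(shape):
    """Label the axes of a 4-D shape under the NCHW assumption."""
    if len(shape) != 4:
        raise ValueError(f"expected 4 dimensions for NCHW, got {len(shape)}")
    batch, channels, height, width = shape
    return {"batch": batch, "channels": channels, "height": height, "width": width}


input_shapes = {"input_1": (1, 3, 224, 224)}
input_types = {"input_1": "float32"}  # placeholder for ScalarType.float32

problems = check_input_dicts(input_shapes, input_types)
layout = as_nchw(input_shapes["input_1"])
```

Reading the shape through as_nchw makes the layout assumption explicit: for (1, 3, 224, 224) the second axis is the channel count, so a model exported with channels last would need its shape reordered before it is passed in.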

The following example demonstrates how to load a floating-point ONNX model with a single input named input_1, which has a shape of (1, 3, 224, 224) in the NCHW layout.

# imports
from afe.load.importers.general_importer import ImporterParams, onnx_source
from afe.apis.loaded_net import load_model
from afe.ir.tensor_type import ScalarType


# model path
onnx_model_path = '../path/to/onnx/model/file'

# input shapes dictionary: each key/value pair maps an input name to its shape
input_shapes = {'input_1': (1, 3, 224, 224)}

# input types dictionary: each key/value pair maps an input name to its data type
input_types = {'input_1': ScalarType.float32}

# importer parameters
importer_params: ImporterParams = onnx_source(
    model_path=onnx_model_path,
    shape_dict=input_shapes,
    dtype_dict=input_types,
)

# load ONNX floating-point model into LoadedNet format
loaded_net = load_model(importer_params)