load_model
As the first step of the PTQ process, the afe.apis.loaded_net.load_model() API loads a model and prepares it for quantization. The parameters that load_model requires depend on the type of model being loaded.
For example, when loading an ONNX model, both input names and input shapes must be specified, but they serve different purposes. Input names are used to identify inputs within the ONNX model and can be directly obtained by examining the model. Input shapes, however, may not always be defined in the ONNX model but are essential for the ModelSDK to process the model correctly. If input shapes are unspecified, they must be explicitly provided based on the expected input dimensions.
Understanding these distinctions is crucial for correctly configuring load_model and diagnosing potential errors during the loading process.
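The distinction between declared and explicitly provided shapes can be sketched in plain Python. This is an illustration only, not part of the ModelSDK: the helper name resolve_shape and the convention of marking an unspecified dimension with a string or None are assumptions made for this example.

```python
from typing import Sequence, Tuple, Union

# A declared dimension is a fixed int, a symbolic name such as "batch", or None.
# (Illustrative convention, not taken from the ModelSDK.)
Dim = Union[int, str, None]


def resolve_shape(declared: Sequence[Dim], concrete: Sequence[int]) -> Tuple[int, ...]:
    """Check a user-provided concrete shape against the shape a model declares.

    Fixed dimensions must match exactly; symbolic or missing dimensions
    accept any positive integer supplied by the user.
    """
    if len(declared) != len(concrete):
        raise ValueError(
            f"rank mismatch: model declares {len(declared)} dims, got {len(concrete)}"
        )
    for axis, (want, got) in enumerate(zip(declared, concrete)):
        if got <= 0:
            raise ValueError(f"axis {axis}: dimension must be positive, got {got}")
        if isinstance(want, int) and want != got:
            raise ValueError(f"axis {axis}: model declares {want}, got {got}")
    return tuple(concrete)


# The model leaves the batch dimension symbolic, so the caller fixes it to 1.
input_shapes = {"input_1": resolve_shape(("batch", 3, 224, 224), (1, 3, 224, 224))}
```

A check of this kind, run before building the importer parameters, surfaces a rank or dimension mismatch as a readable error rather than a failure later in the loading process.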
For ONNX models, it is essential to specify the input names, corresponding shapes, and data types. These attributes are organized using dictionaries, where each input name is mapped to its associated shape and data type.
The input layout is assumed to follow the NCHW format.
The following example demonstrates how to load a floating-point ONNX model with an input named input_1, which has a shape of (1, 3, 224, 224) in the NCHW layout:
# imports
from afe.load.importers.general_importer import ImporterParams, onnx_source
from afe.apis.loaded_net import load_model
from afe.ir.tensor_type import ScalarType
# model path
onnx_model_path='../path/to/onnx/model/file'
# input shapes dictionary: each key,value pair defines an input and its shape
input_shapes = {'input_1': (1,3,224,224)}
# input types dictionary: each key,value pair defines an input and its type
input_types = {'input_1': ScalarType.float32}
# importer parameters
importer_params: ImporterParams = onnx_source(model_path=onnx_model_path,
                                              shape_dict=input_shapes,
                                              dtype_dict=input_types)
# load ONNX floating-point model into LoadedNet format
loaded_net = load_model(importer_params)
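Because shapes and data types are passed as two separate dictionaries keyed by input name, a mismatch between their keys is an easy mistake to make. The following is a minimal sketch of a pre-flight check, assuming a hypothetical helper named check_input_dicts that is not part of the ModelSDK. It treats the data type values as opaque, so it works with ScalarType members or any other value.

```python
from typing import Any, Dict, Tuple


def check_input_dicts(shapes: Dict[str, Tuple[int, ...]], dtypes: Dict[str, Any]) -> None:
    """Raise ValueError unless both dictionaries describe the same inputs.

    Hypothetical helper for illustration; it does not call the ModelSDK.
    """
    missing_types = sorted(set(shapes) - set(dtypes))
    missing_shapes = sorted(set(dtypes) - set(shapes))
    if missing_types or missing_shapes:
        raise ValueError(
            f"inputs without a data type: {missing_types}; "
            f"inputs without a shape: {missing_shapes}"
        )
    for name, shape in shapes.items():
        # Every dimension must be a concrete positive integer.
        if not all(isinstance(d, int) and d > 0 for d in shape):
            raise ValueError(f"input {name!r} has a non-concrete shape: {shape}")


# Passes silently: both dictionaries describe the single input 'input_1'.
check_input_dicts({"input_1": (1, 3, 224, 224)}, {"input_1": "float32"})
```

Calling such a check immediately before the importer function keeps naming errors close to the place where the dictionaries are defined.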
For native PyTorch models, it is necessary to specify the input names and their corresponding shapes. Input names are defined as a list of strings, while input shapes are specified as a list of tuples containing integers. The association between input names and shapes is determined by their respective indices in the lists.
By default, the input layout follows the NCHW format.
The following example demonstrates how to load a floating-point PyTorch model with an input named 1 and a shape of (1, 3, 224, 224) in the NCHW layout:
# imports
from afe.load.importers.general_importer import ImporterParams, pytorch_source
from afe.apis.loaded_net import load_model
# model path
pyt_model_path='../path/to/pytorch/model/file'
# model input names - list of strings
input_names = ['1']
# input shapes - list of tuples
# assumed to be NCHW format - batch size is required
input_shapes = [(1,3,224,224)]
# importer parameters - layout assumed to be NCHW
importer_params: ImporterParams = pytorch_source(model_path=pyt_model_path,
                                                 input_names=input_names,
                                                 input_shapes=input_shapes)
# load PyTorch floating-point model into LoadedNet format
loaded_net = load_model(importer_params)
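Since PyTorch input names and shapes are associated only by list position, it can help to make that pairing explicit before building the importer parameters. The sketch below assumes a hypothetical helper named pair_inputs, which is not part of the ModelSDK. It only mirrors the index-based association described above and rejects lists that cannot be paired unambiguously.

```python
from typing import Dict, List, Tuple


def pair_inputs(names: List[str], shapes: List[Tuple[int, ...]]) -> Dict[str, Tuple[int, ...]]:
    """Pair names[i] with shapes[i], mirroring the index-based association.

    Hypothetical helper for illustration; it does not call the ModelSDK.
    """
    if len(names) != len(shapes):
        raise ValueError(f"{len(names)} input names but {len(shapes)} input shapes")
    if len(set(names)) != len(names):
        raise ValueError(f"duplicate input names: {names}")
    return dict(zip(names, shapes))


# Two inputs: the first name goes with the first shape, and so on.
paired = pair_inputs(["1", "2"], [(1, 3, 224, 224), (1, 10)])
print(paired)
```

Printing the paired dictionary is a quick way to confirm that each name lines up with the intended shape when a model has several inputs.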
For Keras models in HDF5 format, it is necessary to specify a dictionary that maps model inputs to their corresponding shapes. By default, the layout parameter is set to NHWC but can be configured to NCHW if required.
The following example demonstrates how to load a floating-point Keras model with an input named input_1, having a shape of (1, 224, 224, 3) in the NHWC layout:
# imports
from afe.load.importers.general_importer import ImporterParams, keras_source
from afe.apis.loaded_net import load_model
# model path
keras_model_path='../path/to/keras/model/file'
# input shapes dictionary: each key,value pair defines an input and its shape
input_shapes_dict = {'input_1': (1,224,224,3)}
# importer parameters - layout specified as NHWC
importer_params: ImporterParams = keras_source(model_path=keras_model_path,
                                               shape_dict=input_shapes_dict,
                                               layout='NHWC')
# load Keras floating-point model into LoadedNet format
loaded_net = load_model(importer_params)
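The NHWC and NCHW layouts describe the same tensor with its axes in a different order, which is why the same image input appears as (1, 224, 224, 3) in one example and (1, 3, 224, 224) in another. The sketch below shows the axis reordering on shape tuples only. The function names are hypothetical and nothing here calls the ModelSDK.

```python
from typing import Tuple

Shape4D = Tuple[int, int, int, int]


def nhwc_to_nchw(shape: Shape4D) -> Shape4D:
    """Reorder (batch, height, width, channels) to (batch, channels, height, width)."""
    n, h, w, c = shape
    return (n, c, h, w)


def nchw_to_nhwc(shape: Shape4D) -> Shape4D:
    """Reorder (batch, channels, height, width) to (batch, height, width, channels)."""
    n, c, h, w = shape
    return (n, h, w, c)


keras_shape = (1, 224, 224, 3)            # NHWC, as in the Keras example
onnx_style_shape = nhwc_to_nchw(keras_shape)
print(onnx_style_shape)                   # (1, 3, 224, 224)
```

The shape passed to the importer should be written in the same layout that the layout parameter declares.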
For TensorFlow 1 models in protobuf format, it is necessary to specify a dictionary that maps model inputs to their corresponding shapes and to provide a list of output names. By default, the input layout is set to NHWC.
The following example demonstrates how to load a floating-point TensorFlow 1 model with an input named input_1, having a shape of (1, 224, 224, 3) in the NHWC layout, and an output named prediction:
# imports
from afe.load.importers.general_importer import ImporterParams, tensorflow_source
from afe.apis.loaded_net import load_model
# model path
tf_model_path='../path/to/tensorflow/model/file'
# input shapes dictionary: each key,value pair defines an input and its shape
input_shapes = {'input_1': (1,224,224,3)}
# output names - list of strings
output_names = ['prediction']
# importer parameters
importer_params: ImporterParams = tensorflow_source(model_path=tf_model_path,
                                                    input_shapes=input_shapes,
                                                    output_names=output_names)
# load TensorFlow floating-point model into LoadedNet format
loaded_net = load_model(importer_params)
For TFLite models, input names are mapped to their corresponding shapes and data types using dictionaries.
By default, the input layout is set to NHWC.
The following example demonstrates how to load a floating-point TFLite model with an input named input_1, having a shape of (1, 224, 224, 3) in the NHWC layout:
# imports
from afe.load.importers.general_importer import ImporterParams, tflite_source
from afe.apis.loaded_net import load_model
from afe.ir.tensor_type import ScalarType
# model path
tflite_model_path='../path/to/tflite/model/file'
# input shapes dictionary: each key,value pair defines an input and its shape
input_shapes = {'input_1': (1,224,224,3)}
# input types dictionary: each key,value pair defines an input and its type
input_types = {'input_1': ScalarType.float32}
# importer parameters
importer_params: ImporterParams = tflite_source(model_path=tflite_model_path,
                                                shape_dict=input_shapes,
                                                dtype_dict=input_types)
# load TFLite floating-point model into LoadedNet format
loaded_net = load_model(importer_params)
TFLite models that have already been quantized with TFLite quantization can be imported into ModelSDK using the same code as for the floating-point TFLite model, with one exception: the input_types dictionary must specify the quantized data type of each input (int8 or uint8):
# input types dictionary: each key-value pair defines an input and its type
input_types = {'input_1': ScalarType.uint8}