Loading Floating-Point Models
This section describes how to load (import) supported floating-point models into the ModelSDK. A floating-point model is imported with the load_model API, which takes an ImporterParams object whose parameters are specific to the type of model being loaded.
ONNX Format Models
For ONNX models, we need to know the input names, their shapes and their data types. The input names are associated with their shapes and data types using dictionaries.
The input shape dictionary is made up of key, value pairs in which the key is an input name (string) and the value is a tuple of integers defining the shape. The input types dictionary is made up of key, value pairs in which the key is an input name (string) and the value is a ScalarType (generally ScalarType.float32).
The layout is assumed to be NCHW.
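Because the shape and type dictionaries described above are matched by input name, a misspelled key leaves an input with a shape but no type, or the reverse. The sketch below checks both dictionaries before they are passed to the importer. It is plain Python and does not call the ModelSDK: the helper name `check_onnx_inputs` is hypothetical, and the data types are written as strings instead of `ScalarType` members so the sketch runs without the SDK installed.

```python
from typing import Dict, Tuple


def check_onnx_inputs(shape_dict: Dict[str, Tuple[int, ...]],
                      dtype_dict: Dict[str, str]) -> None:
    """Raise ValueError if the two dictionaries do not describe the same inputs."""
    missing_types = set(shape_dict) - set(dtype_dict)
    missing_shapes = set(dtype_dict) - set(shape_dict)
    if missing_types or missing_shapes:
        raise ValueError(
            f"inputs without a type: {sorted(missing_types)}; "
            f"inputs without a shape: {sorted(missing_shapes)}")
    for name, shape in shape_dict.items():
        # This sketch rejects any dimension (including the batch size)
        # that is not a concrete positive integer.
        if not all(isinstance(d, int) and d > 0 for d in shape):
            raise ValueError(f"input {name!r} has an invalid dimension: {shape}")


# Same dictionaries as the ONNX example, with the data type written as a string.
input_shapes = {'input_1': (1, 3, 224, 224)}
input_types = {'input_1': 'float32'}
check_onnx_inputs(input_shapes, input_types)  # passes silently
```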
The example below shows how to load a floating-point ONNX model that has an input named input_1 with a shape of (1, 3, 224, 224) in NCHW format:
# imports
from afe.load.importers.general_importer import ImporterParams, onnx_source
from afe.apis.loaded_net import load_model
from afe.ir.tensor_type import ScalarType
# model path
onnx_model_path='../path/to/onnx/model/file'
# input shapes dictionary: each key,value pair defines an input and its shape
input_shapes = {'input_1': (1,3,224,224)}
# input types dictionary: each key,value pair defines an input and its type
input_types = {'input_1': ScalarType.float32}
# importer parameters
importer_params: ImporterParams = onnx_source(model_path=onnx_model_path,
                                              shape_dict=input_shapes,
                                              dtype_dict=input_types)
# load ONNX floating-point model into LoadedNet format
loaded_net = load_model(importer_params)
PyTorch Format Models
For native PyTorch models, we must specify the input names and their shapes. The input names are given as a list of strings and the input shapes as a list of tuples of integers; each name is associated with the shape at the same list index. The layout defaults to NCHW.
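Since names and shapes are paired by list position, the two lists must have the same length and order. The following plain-Python sketch makes that pairing explicit and fails loudly when the lists disagree; `pair_inputs` is a hypothetical helper, not part of the ModelSDK, and the two-input model is invented for illustration.

```python
from typing import Dict, List, Tuple


def pair_inputs(input_names: List[str],
                input_shapes: List[Tuple[int, ...]]) -> Dict[str, Tuple[int, ...]]:
    """Pair each input name with the shape at the same list index."""
    if len(input_names) != len(input_shapes):
        # zip() would silently drop the unmatched tail, so fail instead.
        raise ValueError(f"{len(input_names)} names but {len(input_shapes)} shapes")
    if len(set(input_names)) != len(input_names):
        raise ValueError("duplicate input names")
    return dict(zip(input_names, input_shapes))


# Hypothetical two-input model: an NCHW image and a small auxiliary tensor.
names = ['image', 'scale']
shapes = [(1, 3, 224, 224), (1, 2)]
print(pair_inputs(names, shapes))
# {'image': (1, 3, 224, 224), 'scale': (1, 2)}
```

Swapping the order of one list without the other would pair each name with the wrong shape, which is why the order matters as much as the length.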
This example shows how to load a floating-point PyTorch model that has an input named '1' with a shape of (1, 3, 224, 224) in NCHW format:
# imports
from afe.load.importers.general_importer import ImporterParams, pytorch_source
from afe.apis.loaded_net import load_model
# model path
pyt_model_path='../path/to/pytorch/model/file'
# model input names - list of strings
input_names = ['1']
# input shapes - list of tuples
# assumed to be NCHW format - the batch size is required
input_shapes = [(1,3,224,224)]
# importer parameters - layout assumed to be NCHW
importer_params: ImporterParams = pytorch_source(model_path=pyt_model_path,
                                                 input_names=input_names,
                                                 input_shapes=input_shapes)
# load PyTorch floating-point model into LoadedNet format
loaded_net = load_model(importer_params)
Keras Format Models
For Keras models in HDF5 format, we must specify a dictionary that associates the model inputs with their shapes. The layout parameter defaults to NHWC but can be set to NCHW.
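Assuming the shapes in shape_dict are expected in the order named by the layout parameter, switching between NHWC and NCHW also means reordering each shape tuple. A minimal plain-Python sketch of that reordering; `convert_shape` is a hypothetical helper, and it assumes 4-D image shapes whose axes are named by the letters N, C, H and W.

```python
from typing import Tuple


def convert_shape(shape: Tuple[int, ...], src_layout: str, dst_layout: str) -> Tuple[int, ...]:
    """Reorder a shape tuple from src_layout to dst_layout, e.g. 'NHWC' -> 'NCHW'."""
    if sorted(src_layout) != sorted(dst_layout) or len(shape) != len(src_layout):
        raise ValueError(f"cannot convert {shape} from {src_layout} to {dst_layout}")
    # For each destination axis, look up its position in the source layout.
    return tuple(shape[src_layout.index(axis)] for axis in dst_layout)


print(convert_shape((1, 224, 224, 3), 'NHWC', 'NCHW'))  # (1, 3, 224, 224)
print(convert_shape((1, 3, 224, 224), 'NCHW', 'NHWC'))  # (1, 224, 224, 3)
```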
This example shows how to load a floating-point Keras model that has an input named input_1 with a shape of (1, 224, 224, 3) in NHWC format:
# imports
from afe.load.importers.general_importer import ImporterParams, keras_source
from afe.apis.loaded_net import load_model
# model path
keras_model_path='../path/to/keras/model/file'
# input shapes dictionary: each key,value pair defines an input and its shape
input_shapes_dict = {'input_1': (1,224,224,3)}
# importer parameters - layout specified as NHWC
importer_params: ImporterParams = keras_source(model_path=keras_model_path,
                                               shape_dict=input_shapes_dict,
                                               layout='NHWC')
# load Keras floating-point model into LoadedNet format
loaded_net = load_model(importer_params)
TensorFlow 1 Format Models
For TensorFlow 1 models in protobuf format, we must specify a dictionary that associates the model inputs with their shapes and provide a list of the output names. The layout defaults to NHWC.
The example below shows how to load a floating-point TensorFlow 1 model that has an input named input_1 with a shape of (1, 224, 224, 3) in NHWC format and an output named prediction:
# imports
from afe.load.importers.general_importer import ImporterParams, tensorflow_source
from afe.apis.loaded_net import load_model
# model path
tf_model_path='../path/to/tensorflow/model/file'
# input shapes dictionary: each key,value pair defines an input and its shape
input_shapes = {'input_1': (1,224,224,3)}
# output names - list of strings
output_names = ['prediction']
# importer parameters
importer_params: ImporterParams = tensorflow_source(model_path=tf_model_path,
                                                    input_shapes=input_shapes,
                                                    output_names=output_names)
# load TensorFlow floating-point model into LoadedNet format
loaded_net = load_model(importer_params)
TFLite Format Models
For TFLite models, the input names are associated with their shapes and data types using dictionaries.
The input shape dictionary is made up of key-value pairs in which the key is an input name (string) and the value is a tuple of integers defining the shape.
The input types dictionary is made up of key-value pairs in which the key is an input name (string) and the value is a ScalarType (generally ScalarType.float32).
The layout defaults to NHWC.
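The default layouts stated in this section differ by format: NCHW for ONNX and PyTorch, NHWC for Keras, TensorFlow 1 and TFLite. The sketch below records those defaults and builds a 4-D image shape in the matching axis order. It is plain Python; the table keys and the `image_input_shape` helper are hypothetical labels, not ModelSDK names.

```python
from typing import Dict, Tuple

# Default input layouts as stated in this section; the keys are hypothetical labels.
DEFAULT_LAYOUT: Dict[str, str] = {
    'onnx': 'NCHW',
    'pytorch': 'NCHW',
    'keras': 'NHWC',
    'tensorflow1': 'NHWC',
    'tflite': 'NHWC',
}


def image_input_shape(model_format: str, n: int, c: int, h: int, w: int) -> Tuple[int, ...]:
    """Build a 4-D image shape in the default layout of the given model format."""
    dims = {'N': n, 'C': c, 'H': h, 'W': w}
    return tuple(dims[axis] for axis in DEFAULT_LAYOUT[model_format])


print(image_input_shape('onnx', 1, 3, 224, 224))    # (1, 3, 224, 224)
print(image_input_shape('tflite', 1, 3, 224, 224))  # (1, 224, 224, 3)
```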
The example below shows how to load a floating-point TFLite model that has an input named input_1 with a shape of (1, 224, 224, 3) in NHWC format:
# imports
from afe.load.importers.general_importer import ImporterParams, tflite_source
from afe.apis.loaded_net import load_model
from afe.ir.tensor_type import ScalarType
# model path
tflite_model_path='../path/to/tflite/model/file'
# input shapes dictionary: each key,value pair defines an input and its shape
input_shapes = {'input_1': (1,224,224,3)}
# input types dictionary: each key,value pair defines an input and its type
input_types = {'input_1': ScalarType.float32}
# importer parameters
importer_params: ImporterParams = tflite_source(model_path=tflite_model_path,
                                                shape_dict=input_shapes,
                                                dtype_dict=input_types)
# load TFLite floating-point model into LoadedNet format
loaded_net = load_model(importer_params)
Pre-Quantized TFLite Models
TFLite models that have been quantized with TFLite quantization can also be imported into the ModelSDK using the same code as a floating-point TFLite model, with one exception: the input_types dictionary must specify the quantized data type of each input (int8 or uint8):
# input types dictionary: each key-value pair defines an input and its type
input_types = {'input_1': ScalarType.uint8}