.. _Model Loading:

load_model
##########

As the first step of the PTQ process, the :py:meth:`afe.apis.loaded_net.load_model` API loads a model and prepares it for quantization. The parameters that ``load_model`` requires depend on the type of model being loaded.

For example, loading an ONNX model requires both **input names** and **input shapes**, which serve different purposes. Input names identify inputs within the ONNX model and can be obtained directly by examining the model. Input shapes may not always be defined in the ONNX model, but the ModelSDK needs them to process the model correctly. If the model leaves its input shapes unspecified, they must be provided explicitly based on the expected input dimensions. Understanding this distinction is crucial for configuring ``load_model`` correctly and for diagnosing errors during the loading process.

.. tabs::

   .. tab:: ONNX Format Models

      For ONNX models, it is essential to specify the input names, their shapes, and their data types. These attributes are organized in dictionaries that map each input name to its shape and to its data type. The input layout is assumed to follow the ``NCHW`` format.

      The following example demonstrates how to load a floating-point ONNX model with an input named ``input_1``, which has a shape of ``1,3,224,224`` in the ``NCHW`` layout.

      .. code-block:: python

         # imports
         from afe.load.importers.general_importer import ImporterParams, onnx_source
         from afe.apis.loaded_net import load_model
         from afe.ir.tensor_type import ScalarType

         # model path
         onnx_model_path = '../path/to/onnx/model/file'

         # input shapes dictionary: each key,value pair defines an input and its shape
         input_shapes = {'input_1': (1, 3, 224, 224)}

         # input types dictionary: each key,value pair defines an input and its type
         input_types = {'input_1': ScalarType.float32}

         # importer parameters
         importer_params: ImporterParams = onnx_source(model_path=onnx_model_path,
                                                       shape_dict=input_shapes,
                                                       dtype_dict=input_types)

         # load ONNX floating-point model into LoadedNet format
         loaded_net = load_model(importer_params)

   .. tab:: PyTorch Format Models

      For native PyTorch models, it is necessary to specify the input names and their corresponding shapes. Input names are given as a list of strings, and input shapes as a list of tuples of integers. Names and shapes are associated by position: the shape at index ``i`` belongs to the name at index ``i``. By default, the input layout follows the ``NCHW`` format.

      The following example demonstrates how to load a floating-point PyTorch model with an input named ``1`` and a shape of ``1,3,224,224`` in the ``NCHW`` layout:

      .. code-block:: python

         # imports
         from afe.load.importers.general_importer import ImporterParams, pytorch_source
         from afe.apis.loaded_net import load_model

         # model path
         pyt_model_path = '../path/to/pytorch/model/file'

         # model input names - list of strings
         input_names = ['1']

         # input shapes - list of tuples
         # assumed to be NCHW format - batch size is required
         input_shapes = [(1, 3, 224, 224)]

         # importer parameters - layout assumed to be NCHW
         importer_params: ImporterParams = pytorch_source(model_path=pyt_model_path,
                                                          input_names=input_names,
                                                          input_shapes=input_shapes)

         # load PyTorch floating-point model into LoadedNet format
         loaded_net = load_model(importer_params)

   .. tab:: Keras Format Models

      For Keras models in HDF5 format, it is necessary to specify a dictionary that maps model inputs to their corresponding shapes. By default, the ``layout`` parameter is set to ``NHWC`` but can be configured to ``NCHW`` if required.

      The following example demonstrates how to load a floating-point Keras model with an input named ``input_1``, having a shape of ``1,224,224,3`` in the ``NHWC`` layout:

      .. code-block:: python

         # imports
         from afe.load.importers.general_importer import ImporterParams, keras_source
         from afe.apis.loaded_net import load_model

         # model path
         keras_model_path = '../path/to/keras/model/file'

         # input shapes dictionary: each key,value pair defines an input and its shape
         input_shapes_dict = {'input_1': (1, 224, 224, 3)}

         # importer parameters - layout specified as NHWC
         importer_params: ImporterParams = keras_source(model_path=keras_model_path,
                                                        shape_dict=input_shapes_dict,
                                                        layout='NHWC')

         # load Keras floating-point model into LoadedNet format
         loaded_net = load_model(importer_params)

   .. tab:: TensorFlow1 Format Models

      For TensorFlow 1 models in protobuf format, it is necessary to specify a dictionary that maps model inputs to their corresponding shapes and to provide a list of output names. By default, the input layout is set to ``NHWC``.

      The following example demonstrates how to load a floating-point TensorFlow 1 model with an input named ``input_1`` having a shape of ``1,224,224,3`` in the ``NHWC`` layout, and an output named ``prediction``:

      .. code-block:: python

         # imports
         from afe.load.importers.general_importer import ImporterParams, tensorflow_source
         from afe.apis.loaded_net import load_model

         # model path
         tf_model_path = '../path/to/tensorflow/model/file'

         # input shapes dictionary: each key,value pair defines an input and its shape
         input_shapes = {'input_1': (1, 224, 224, 3)}

         # output names - list of strings
         output_names = ['prediction']

         # importer parameters
         importer_params: ImporterParams = tensorflow_source(model_path=tf_model_path,
                                                             input_shapes=input_shapes,
                                                             output_names=output_names)

         # load TensorFlow floating-point model into LoadedNet format
         loaded_net = load_model(importer_params)

   .. tab:: TFLite Format Models

      For TFLite models, input names are mapped to their corresponding shapes and data types using dictionaries. By default, the input layout is set to ``NHWC``.

      The following example demonstrates how to load a floating-point TFLite model with an input named ``input_1``, having a shape of ``1,224,224,3`` in the ``NHWC`` layout:

      .. code-block:: python

         # imports
         from afe.load.importers.general_importer import ImporterParams, tflite_source
         from afe.apis.loaded_net import load_model
         from afe.ir.tensor_type import ScalarType

         # model path
         tflite_model_path = '../path/to/tflite/model/file'

         # input shapes dictionary: each key,value pair defines an input and its shape
         input_shapes = {'input_1': (1, 224, 224, 3)}

         # input types dictionary: each key,value pair defines an input and its type
         input_types = {'input_1': ScalarType.float32}

         # importer parameters
         importer_params: ImporterParams = tflite_source(model_path=tflite_model_path,
                                                         shape_dict=input_shapes,
                                                         dtype_dict=input_types)

         # load TFLite floating-point model into LoadedNet format
         loaded_net = load_model(importer_params)

   .. tab:: Pre-Quantized TFLite Models

      TFLite models that have been quantized with TFLite quantization can be imported into ModelSDK using exactly the same code as for the floating-point TFLite model, with one exception: the ``input_types`` dictionary must indicate the correct quantized data type (``int8`` or ``uint8``):

      .. code-block:: python

         # input types dictionary: each key-value pair defines an input and its type
         input_types = {'input_1': ScalarType.uint8}
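
Because each importer expects its input names, shapes, and types to line up (by dictionary key for ONNX and TFLite, by list index for PyTorch), it can help to check the arguments before building the importer parameters. The following is a minimal sketch of such a check, not part of the ModelSDK: the helper names ``check_input_dicts`` and ``pair_names_with_shapes`` are hypothetical, and the data types are left as opaque values so the snippet runs without ``afe`` installed.

.. code-block:: python

   from typing import Dict, List, Sequence, Tuple

   Shape = Tuple[int, ...]


   def check_input_dicts(shape_dict: Dict[str, Shape],
                         dtype_dict: Dict[str, object]) -> None:
       """Hypothetical helper: raise ValueError unless both dictionaries
       describe exactly the same inputs and every dimension is a positive int."""
       missing_types = sorted(set(shape_dict) - set(dtype_dict))
       missing_shapes = sorted(set(dtype_dict) - set(shape_dict))
       if missing_types or missing_shapes:
           raise ValueError(
               f"inputs without a type: {missing_types}; "
               f"inputs without a shape: {missing_shapes}"
           )
       for name, shape in shape_dict.items():
           if not all(isinstance(dim, int) and dim > 0 for dim in shape):
               raise ValueError(f"input {name!r} has an invalid shape: {shape}")


   def pair_names_with_shapes(input_names: Sequence[str],
                              input_shapes: List[Shape]) -> Dict[str, Shape]:
       """Hypothetical helper: pair names and shapes by index, as described
       for PyTorch models, and fail if the two lists differ in length."""
       if len(input_names) != len(input_shapes):
           raise ValueError(
               f"{len(input_names)} input names but {len(input_shapes)} input shapes"
           )
       return dict(zip(input_names, input_shapes))


   # Dictionary-style arguments (ONNX, TFLite); 'float32' stands in for a
   # real type value such as ScalarType.float32.
   check_input_dicts({'input_1': (1, 3, 224, 224)}, {'input_1': 'float32'})

   # List-style arguments (PyTorch)
   paired = pair_names_with_shapes(['1'], [(1, 3, 224, 224)])

Running a check like this before calling the importer catches a mistyped input name or a missing shape as a clear ``ValueError`` in your own code instead of as an error during model loading.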