# Source code for afe.core.parse_networks

#########################################################
# Copyright (C) 2022 SiMa Technologies, Inc.
#
# This material is SiMa proprietary and confidential.
#
# This material may not be copied or distributed without
# the express prior written permission of SiMa.
#
# All rights reserved.
#########################################################
# Code owner: Nenad Nikolic
#########################################################
import os.path
from typing import Tuple, List, Optional

from afe._tvm._defines import TVMIRModule, TVMNDArray
from afe.core.configs import ModelConfigs


class _InvalidModelConfigError(ValueError, AssertionError):
    """
    Raised when a model config fails the tensor-name checks.

    Derives from ValueError, the idiomatic type for a bad argument. It also derives from
    AssertionError because these checks used to be ``assert`` statements, and existing
    callers may still catch AssertionError.
    """


def _validate_tensor_names(names: List[str], kind: str) -> None:
    """
    Check that a list of tensor names is non-empty and contains no empty strings.

    :param names: Tensor names taken from the model config.
    :param kind: ``"input"`` or ``"output"``; used only to build the error message.
    :raises _InvalidModelConfigError: If ``names`` is empty or contains ``""``.
    """
    if not names:
        raise _InvalidModelConfigError(f"Model must have at least one {kind} tensor")
    if any(name == "" for name in names):
        raise _InvalidModelConfigError(f"{kind.capitalize()} tensor name cannot be an empty string")


def import_model_from_model_configs(config: ModelConfigs) -> Tuple[TVMIRModule, Optional[List[str]]]:
    """
    Import a model from a file written in any of the supported libraries, and return it as a
    TVM IR module.

    :param config: ModelConfigs of the model being loaded. Contains information like path,
        framework and other.
    :return: Tuple of the model converted to TVM IR module and the model's output names.
        Every framework except ONNX returns ``None`` as the second element. The ONNX
        branch returns the ONNX importer's result unchanged.
    :raises ValueError: If ``config.input_names`` is empty or contains an empty string, or if
        ``config.output_names`` is given and is empty or contains an empty string. The raised
        error is also an AssertionError, for compatibility with the previous assert-based
        checks.
    :raises KeyError: If ``config.framework`` is not one of the supported frameworks.
    """
    framework = config.framework

    # Explicit checks instead of `assert`, so validation still runs under `python -O`.
    _validate_tensor_names(config.input_names, "input")
    if config.output_names is not None:
        _validate_tensor_names(config.output_names, "output")

    # Each importer module is imported inside its own branch, so only the selected
    # framework's importer is loaded.
    if framework == 'tensorflow':
        from afe.load.importers.tensorflow import import_tensorflow_pb_to_tvm
        return import_tensorflow_pb_to_tvm(config.model_path, config.shape_dict,
                                           config.output_names, config.layout), None
    elif framework == 'tensorflow2':
        from afe.load.importers.tensorflow import import_tensorflow2_pb_to_tvm
        return import_tensorflow2_pb_to_tvm(config.model_path, config.shape_dict,
                                            config.output_names, config.layout), None
    elif framework == 'pytorch':
        from afe.load.importers.pytorch import import_pytorch_to_tvm
        return import_pytorch_to_tvm(config.model_path, config.input_names,
                                     config.input_shapes, config.input_dtypes), None
    elif framework == 'onnx':
        from afe.load.importers.onnx import import_onnx_to_tvm
        from afe.tvm_converter.custom_convert_maps import CUSTOM_CONVERT_MAP_DICT
        custom_convert_map = CUSTOM_CONVERT_MAP_DICT["ONNX"]
        # Unlike the other branches, the importer's result is returned as-is.
        # NOTE(review): presumably import_onnx_to_tvm already returns the
        # (module, output_names) pair; confirm against afe.load.importers.onnx.
        return import_onnx_to_tvm(config.model_path, config.shape_dict, config.dtype_dict,
                                  custom_convert_map)
    elif framework == 'tflite':
        from afe.load.importers.tflite import import_tflite_to_tvm
        return import_tflite_to_tvm(config.model_path, config.shape_dict,
                                    config.dtype_dict), None
    elif framework == 'keras':
        from afe.load.importers.keras import import_keras_to_tvm
        return import_keras_to_tvm(config.model_path, config.shape_dict, config.layout), None
    elif framework == 'caffe':
        from afe.load.importers.caffe import import_caffe_to_tvm
        # Caffe models come as two files, so model_file_paths needs at least two entries.
        return import_caffe_to_tvm(config.model_file_paths[0], config.model_file_paths[1],
                                   config.shape_dict, config.dtype_dict), None
    elif framework == 'caffe2':
        from afe.load.importers.caffe2 import import_caffe2_to_tvm
        # Caffe2 models also come as two files; model_file_paths needs at least two entries.
        return import_caffe2_to_tvm(config.model_file_paths[0], config.model_file_paths[1],
                                    config.shape_dict, config.dtype_dict), None
    else:
        raise KeyError(f"Error: Unrecognized framework ({framework}) when trying to load model")