#########################################################
# Copyright (C) 2022 SiMa Technologies, Inc.
#
# This material is SiMa proprietary and confidential.
#
# This material may not be copied or distributed without
# the express prior written permission of SiMa.
#
# All rights reserved.
#########################################################
# Code owner: Nenad Nikolic
#########################################################
import os.path
from typing import Tuple, List, Optional
from afe._tvm._defines import TVMIRModule, TVMNDArray
from afe.core.configs import ModelConfigs
[docs]
def import_model_from_model_configs(config: ModelConfigs) -> Tuple[TVMIRModule, Optional[List[str]]]:
    """
    Importing a model written in any of the supported libraries from a file,
    and returning it as TVM IR module.

    The framework-specific importer is selected by ``config.framework`` and is
    imported lazily inside its branch, so only the importer for the selected
    framework is loaded.

    :param config: ModelConfigs of the model being loaded. Contains information like path, framework and other.
    :return: Tuple of the model converted to TVM IR module and the model's output names.
        Every branch except ONNX returns ``None`` as the output names. The ONNX branch
        returns the result of ``import_onnx_to_tvm`` unchanged (presumably an
        ``(IRModule, output_names)`` pair -- confirm against that importer).
    :raises AssertionError: If ``config.input_names`` is empty or contains an empty
        string, or if ``config.output_names`` is not None and is empty or contains an
        empty string.
    :raises KeyError: If ``config.framework`` is not a supported framework name.
    """
    framework = config.framework

    # Validate tensor names with explicit raises instead of `assert` statements:
    # asserts are removed when Python runs with -O, which would silently skip
    # this validation. AssertionError is kept as the exception type so existing
    # callers that catch it keep working.
    if len(config.input_names) == 0:
        raise AssertionError("Model must have at least one input tensor")
    if any(name == "" for name in config.input_names):
        raise AssertionError("Input tensor name cannot be an empty string")
    if config.output_names is not None:
        if len(config.output_names) == 0:
            raise AssertionError("Model must have at least one output tensor")
        if any(name == "" for name in config.output_names):
            raise AssertionError("Output tensor name cannot be an empty string")

    if framework == 'tensorflow':
        from afe.load.importers.tensorflow import import_tensorflow_pb_to_tvm
        return import_tensorflow_pb_to_tvm(config.model_path, config.shape_dict, config.output_names,
                                           config.layout), None
    elif framework == 'tensorflow2':
        from afe.load.importers.tensorflow import import_tensorflow2_pb_to_tvm
        return import_tensorflow2_pb_to_tvm(config.model_path, config.shape_dict, config.output_names,
                                            config.layout), None
    elif framework == 'pytorch':
        from afe.load.importers.pytorch import import_pytorch_to_tvm
        return import_pytorch_to_tvm(config.model_path, config.input_names, config.input_shapes,
                                     config.input_dtypes), None
    elif framework == 'onnx':
        from afe.load.importers.onnx import import_onnx_to_tvm
        from afe.tvm_converter.custom_convert_maps import CUSTOM_CONVERT_MAP_DICT
        custom_convert_map = CUSTOM_CONVERT_MAP_DICT["ONNX"]
        # Unlike the other branches, no `None` is appended here: the importer's
        # return value is passed through as-is (presumably it already carries
        # the output names -- confirm against import_onnx_to_tvm).
        return import_onnx_to_tvm(config.model_path, config.shape_dict, config.dtype_dict, custom_convert_map)
    elif framework == 'tflite':
        from afe.load.importers.tflite import import_tflite_to_tvm
        return import_tflite_to_tvm(config.model_path, config.shape_dict, config.dtype_dict), None
    elif framework == 'keras':
        from afe.load.importers.keras import import_keras_to_tvm
        return import_keras_to_tvm(config.model_path, config.shape_dict, config.layout), None
    elif framework == 'caffe':
        from afe.load.importers.caffe import import_caffe_to_tvm
        # The first two entries of model_file_paths are passed positionally; their
        # expected order is defined by import_caffe_to_tvm.
        return import_caffe_to_tvm(config.model_file_paths[0], config.model_file_paths[1],
                                   config.shape_dict, config.dtype_dict), None
    elif framework == 'caffe2':
        from afe.load.importers.caffe2 import import_caffe2_to_tvm
        # The first two entries of model_file_paths are passed positionally; their
        # expected order is defined by import_caffe2_to_tvm.
        return import_caffe2_to_tvm(config.model_file_paths[0], config.model_file_paths[1], config.shape_dict,
                                    config.dtype_dict), None
    else:
        raise KeyError(f"Error: Unrecognized framework ({framework}) when trying to load model")