from typing import Tuple, Dict, Optional, Callable, List
from afe._tvm._defines import TVMIRModule, TVMNDArray
from afe.ir.tensor_type import ScalarType
[docs]
def import_onnx_model(file_path: str):
    """
    Read an ONNX model from a file on disk.

    :param file_path: str. Path to the .onnx file to load.
    :return: The model object produced by ``onnx.load`` for that file.
    """
    # Function-scope import: the ``onnx`` name is bound only for the
    # duration of this call and is released when the function returns.
    import onnx

    return onnx.load(file_path)
[docs]
def import_onnx_to_tvm(onnx_file_path: str,
                       shape_dict: Dict[str, Tuple[int, ...]],
                       dtype_dict: Dict[str, ScalarType],
                       custom_convert_map: Optional[Dict[str, Callable]] = None
                       ) -> Tuple[TVMIRModule, List[str]]:
    """
    Import an ONNX model file into TVM Relay IR through the TVM frontend.

    The arguments are validated first, then the model is loaded from disk,
    converted to a TVM IR module, and the names of the model's graph outputs
    are collected.

    :param onnx_file_path: File path to the .onnx file.
    :param shape_dict: Dictionary of input names to input shapes (eg. (1,224,224,3)).
    :param dtype_dict: Dictionary of input names to input types (eg. float32 or int64).
    :param custom_convert_map: Optional custom op conversion map from operation names to
                               functions. Whenever an operator whose name is in the map is
                               found in TVM, the function is called with 3 arguments:
                                   inputs = a tvm onnx_input object which contains a dictionary
                                            of tvm function inputs
                                   attr = a dictionary of operation attributes
                                   params = a dictionary of all the constants in the onnx network
                               The function returns the TVM Relay IR expression that is inserted
                               into the model wherever the operation occurs.
    :return: Tuple of the imported TVM IR module and the list of names of the
             ONNX model's outputs.
    """
    # Function-scope import of the private converter; it is resolved before
    # any validation or file access happens.
    from afe._tvm._importers._onnx_importer import _onnx_to_tvm_ir

    # NOTE(review): validate_input_parameters is neither defined nor imported
    # in the visible part of this file - confirm it is in scope at runtime.
    validate_input_parameters(onnx_file_path, shape_dict, dtype_dict)

    model = import_onnx_model(onnx_file_path)
    tvm_module = _onnx_to_tvm_ir(model, shape_dict, dtype_dict, custom_convert_map)

    # Output names are read from the ONNX graph after conversion, in graph order.
    return tvm_module, [graph_output.name for graph_output in model.graph.output]