from typing import Tuple, Dict
from afe._tvm._defines import TVMIRModule, TVMNDArray
from afe.ir.tensor_type import ScalarType
[docs]
def import_tflite_model(file_path: str):
    """
    Load a TFLite model from a ``.tflite`` flatbuffer file.

    The whole file is read into memory and parsed with
    ``tflite.Model.GetRootAsModel``.

    :param file_path: str. Path to the ``.tflite`` model file.
    :return: The root model object returned by ``tflite.Model.GetRootAsModel``.
    """
    # Imported lazily so the ``tflite`` package is only required when a
    # TFLite model is actually loaded.
    import tflite

    # Use a context manager so the file handle is closed deterministically,
    # even if reading raises (the original left the handle open).
    with open(file_path, 'rb') as model_file:
        tflite_model_buf = model_file.read()

    # Get TFLite model from buffer (offset 0 = start of the flatbuffer).
    return tflite.Model.GetRootAsModel(tflite_model_buf, 0)
[docs]
def import_tflite_to_tvm(tflite_file_path: str,
                         shape_dict: Dict[str, Tuple[int, ...]],
                         dtype_dict: Dict[str, ScalarType]) -> TVMIRModule:
    """
    Convert a TFLite flatbuffer file into a TVM Relay IR module.

    The arguments are first passed to ``validate_input_parameters``. The file is
    then parsed with ``import_tflite_model``, and the parsed model is handed to the
    internal TFLite-to-TVM converter together with the input shapes and types.

    :param tflite_file_path: Path to the ``.tflite`` model file.
    :param shape_dict: Mapping from model input name to input shape (e.g. (1, 224, 224, 3)).
    :param dtype_dict: Mapping from model input name to input scalar type (e.g. float32 or int64).
    :return: The imported TVM IR module.
    """
    # Deferred import: the converter lives in a private TVM importer module.
    from afe._tvm._importers._tflite_importer import _tflite_to_tvm_ir as convert_to_relay

    validate_input_parameters(tflite_file_path, shape_dict, dtype_dict)

    parsed_model = import_tflite_model(tflite_file_path)
    relay_module = convert_to_relay(parsed_model,
                                    shape_dict=shape_dict,
                                    dtype_dict=dtype_dict)
    return relay_module