from typing import List, Tuple, Dict
from afe._tvm._defines import TVMIRModule, TVMNDArray
# [docs]  (Sphinx "view source" link marker from the rendered documentation; not part of the module)
def import_keras_model(file_path: str, input_shapes: List[Tuple[int, ...]]):
    """
    Load a Keras model from a file path.

    :param file_path: str. Filepath to the keras .h5 file
    :param input_shapes: List[Tuple[int, ...]]. List of input shapes
        (eg. (1,224,224,3)). Not used when loading the model; the parameter
        is kept so existing callers continue to work.
    :return: A Keras model
    """
    # Imported inside the function so keras is only loaded when a model is
    # actually imported, not when this module is imported.
    import keras

    # The local name ``keras`` goes out of scope on return, so no explicit
    # ``del`` is needed.
    return keras.models.load_model(file_path)
# [docs]  (Sphinx "view source" link marker from the rendered documentation; not part of the module)
def import_keras_to_tvm(keras_file_path: str,
                        shape_dict: Dict[str, Tuple[int, ...]],
                        layout: str) -> TVMIRModule:
    """
    Load a Keras model from disk and convert it into a TVM IR module.

    The input parameters are passed to ``validate_input_parameters`` first,
    then the model is loaded with ``import_keras_model`` and handed to the
    Keras-to-TVM importer.

    :param keras_file_path: filepath to the keras .h5 file
    :param shape_dict: dictionary of input names to input shapes (eg. (1,224,224,3))
    :param layout: How data should be arranged in the imported model.
    :return: Imported TVM IR module.
    """
    # Deferred import: the TVM Keras importer is only pulled in when needed.
    from afe._tvm._importers._keras_importer import _keras_to_tvm_ir

    # NOTE(review): validate_input_parameters is not defined or imported in the
    # visible part of this file -- confirm it is in scope at module level.
    validate_input_parameters(keras_file_path, shape_dict)

    # Only the shapes (in dictionary order) are forwarded to the model loader.
    input_shapes = list(shape_dict.values())
    loaded_model = import_keras_model(keras_file_path, input_shapes)

    tvm_module = _keras_to_tvm_ir(loaded_model, shape_dict, layout)
    return tvm_module