"""Source code for afe.load.importers.keras: helpers for importing Keras models into TVM."""

from typing import List, Tuple, Dict
from afe._tvm._defines import TVMIRModule, TVMNDArray


def import_keras_model(file_path: str, input_shapes: List[Tuple[int, ...]]):
    """
    Load a Keras model from a file path.

    :param file_path: str. Filepath to the Keras .h5 file.
    :param input_shapes: List[Tuple[int, ...]]. List of input shapes (eg. (1,224,224,3)).
        NOTE: currently unused by this function; the model is loaded from
        ``file_path`` alone.
    :return: The Keras model returned by ``keras.models.load_model``.
    """
    # Imported inside the function rather than at module level, presumably so
    # keras is only required when a Keras import is actually performed.
    import keras

    return keras.models.load_model(file_path)
def validate_input_parameters(keras_file_path: str, shape_dict: Dict[str, Tuple[int, ...]]) -> None:
    """
    Validate the user-supplied input for a Keras import.

    :param keras_file_path: filepath to the keras .h5 file; must be a str ending in ``.h5``.
    :param shape_dict: dictionary of input names to input shapes (eg. (1,224,224,3));
        must be non-empty.
    :raises AssertionError: if ``keras_file_path`` is not a str ending in ``.h5``,
        or if ``shape_dict`` is empty or None.
    """
    # Raise explicitly instead of using ``assert`` so the checks still run under
    # ``python -O``. AssertionError is kept so existing callers that catch it
    # keep working.
    # ``endswith`` (rather than a substring test) rejects paths such as
    # "model.h5.bak" or "/data.h5/model.onnx" that merely contain ".h5".
    if not (isinstance(keras_file_path, str) and keras_file_path.endswith('.h5')):
        raise AssertionError("Error: We expect a .h5 file to be supplied for a Keras import")
    if not shape_dict:
        raise AssertionError("Error: Please supply a shape dictionary in the form of input names to input shapes")
def import_keras_to_tvm(keras_file_path: str, shape_dict: Dict[str, Tuple[int, ...]], layout: str) -> TVMIRModule:
    """
    Import a Keras .h5 model and convert it to a TVM IR module.

    The converter is imported first, then the inputs are validated, the Keras
    model is loaded from disk, and finally it is converted with the given
    shape dictionary and layout.

    :param keras_file_path: filepath to the keras .h5 file
    :param shape_dict: dictionary of input names to input shapes (eg. (1,224,224,3))
    :param layout: How data should be arranged in the imported model.
    :return: Imported TVM IR module.
    """
    from afe._tvm._importers._keras_importer import _keras_to_tvm_ir

    validate_input_parameters(keras_file_path, shape_dict)

    input_shape_list = list(shape_dict.values())
    loaded_model = import_keras_model(keras_file_path, input_shape_list)

    tvm_module = _keras_to_tvm_ir(loaded_model, shape_dict, layout)
    return tvm_module