# Source code for afe.load.importers.caffe

#########################################################
# Copyright (C) 2020 SiMa Technologies, Inc.
#
# This material is SiMa proprietary and confidential.
#
# This material may not be copied or distributed without
# the express prior written permission of SiMa.
#
# All rights reserved.
#########################################################
# Code owner: Jeffrey Spitz
#########################################################
import os
from typing import Tuple, Dict
from afe._tvm._defines import TVMIRModule, TVMNDArray
from afe.ir.tensor_type import ScalarType


def import_caffe_model(prototxt_file_path: str, caffemodel_file_path: str):
    """
    Build a Caffe network from its definition file and its trained-weights file.

    :param prototxt_file_path: filepath to the caffe .prototxt file (network definition)
    :param caffemodel_file_path: filepath to the caffe .caffemodel file (trained weights)
    :return: the ``caffe.Net`` created from both files in the ``caffe.TEST`` phase
    """
    # Imported inside the function, so importing this module does not itself require caffe.
    import caffe

    phase = caffe.TEST
    return caffe.Net(prototxt_file_path, caffemodel_file_path, phase)
def validate_input_parameters(prototxt_file_path: str, caffemodel_file_path: str,
                              shape_dict: Dict[str, Tuple[int, ...]],
                              dtype_dict: Dict[str, ScalarType]) -> None:
    """
    Validates the user supplied input.

    The checks raise ``AssertionError`` explicitly instead of using ``assert``
    statements, so they still run when Python is started with ``-O``. The
    exception type, messages and check order are the same as before.

    :param prototxt_file_path: filepath to the caffe .prototxt file
    :param caffemodel_file_path: filepath to the caffe .caffemodel file
    :param shape_dict: dictionary of input names to input shapes (eg. (1,224,224,3))
    :param dtype_dict: dictionary of input names to input types (eg. float32 or int64)
    :raises AssertionError: if a path does not contain the expected extension
        substring, a file does not exist, or either dictionary is empty/None.
    """
    # Extension checks are substring tests (not suffix tests), matching prior behavior.
    if '.prototxt' not in prototxt_file_path:
        raise AssertionError("Error: We expect a .prototxt file to be supplied for a caffe import")
    if '.caffemodel' not in caffemodel_file_path:
        raise AssertionError("Error: We expect a .caffemodel file to be supplied for a caffe import")

    # Existence checks for both files.
    if not os.path.exists(prototxt_file_path):
        raise AssertionError(f"caffe .prototxt file cannot be found at {prototxt_file_path}")
    if not os.path.exists(caffemodel_file_path):
        raise AssertionError(f"caffe .caffemodel file cannot be found at {caffemodel_file_path}")

    # Both dictionaries must be supplied and non-empty.
    if not shape_dict:
        raise AssertionError("Error: Please supply a shape dictionary in the form of input names to input shapes")
    if not dtype_dict:
        raise AssertionError("Error: Please supply an dtype dictionary in the form of input names to input types")
def import_caffe_to_tvm(prototxt_file_path: str, caffemodel_file_path: str,
                        shape_dict: Dict[str, Tuple[int, ...]],
                        dtype_dict: Dict[str, ScalarType]) -> TVMIRModule:
    """
    Use TVM frontend to import a caffe model into TVM Relay IR.

    The network definition (.prototxt, protobuf text format) and the trained
    weights (.caffemodel, binary protobuf) are each parsed into a caffe
    ``NetParameter`` and handed to the TVM caffe importer.

    :param prototxt_file_path: filepath to the caffe .prototxt file
    :param caffemodel_file_path: filepath to the caffe .caffemodel file
    :param shape_dict: dictionary of input names to input shapes (eg. (1,224,224,3))
    :param dtype_dict: dictionary of input names to input types (eg. float32 or int64)
    :return: Imported TVM IR module.
    :raises: whatever ``validate_input_parameters`` raises for invalid inputs,
        plus I/O or protobuf parsing errors from reading the two files.
    """
    from caffe.proto import caffe_pb2 as pb
    from google.protobuf import text_format
    from afe._tvm._importers._caffe_importer import _caffe_to_tvm_ir

    validate_input_parameters(prototxt_file_path, caffemodel_file_path, shape_dict, dtype_dict)

    # Weights: populated from the binary .caffemodel file below.
    init_net = pb.NetParameter()
    # Network definition: populated from the text-format .prototxt file below.
    predict_net = pb.NetParameter()

    # Protobuf text format is UTF-8; decode explicitly rather than relying on
    # the platform's locale-dependent default encoding.
    with open(prototxt_file_path, "r", encoding="utf-8") as f:
        text_format.Merge(f.read(), predict_net)

    with open(caffemodel_file_path, "rb") as f:
        init_net.ParseFromString(f.read())

    return _caffe_to_tvm_ir(init_net, predict_net, shape_dict, dtype_dict)