#########################################################
# Copyright (C) 2020 SiMa Technologies, Inc.
#
# This material is SiMa proprietary and confidential.
#
# This material may not be copied or distributed without
# the express prior written permission of SiMa.
#
# All rights reserved.
#########################################################
# Code owner: Jeffrey Spitz
#########################################################
import os
from typing import Tuple, Dict
from afe._tvm._defines import TVMIRModule, TVMNDArray
from afe.ir.tensor_type import ScalarType
[docs]
def import_caffe_model(prototxt_file_path: str,
                       caffemodel_file_path: str):
    """
    Construct a Caffe network from a .prototxt file and a .caffemodel file.

    The network is instantiated in the ``caffe.TEST`` phase. ``caffe`` is
    imported inside the function, so it is only required when this function
    is actually called.

    :param prototxt_file_path: filepath to the caffe .prototxt file
    :param caffemodel_file_path: filepath to the caffe .caffemodel file
    :return: the ``caffe.Net`` object built from the two files.
    """
    import caffe
    return caffe.Net(prototxt_file_path, caffemodel_file_path, caffe.TEST)
[docs]
def import_caffe_to_tvm(prototxt_file_path: str,
                        caffemodel_file_path: str,
                        shape_dict: Dict[str, Tuple[int, ...]],
                        dtype_dict: Dict[str, ScalarType]) -> TVMIRModule:
    """
    Use the TVM frontend to import a Caffe model into TVM Relay IR.

    The .prototxt file is parsed as a text-format ``NetParameter`` and the
    .caffemodel file as a binary-serialized ``NetParameter``. Both messages
    are then handed to the internal Caffe-to-TVM converter.

    :param prototxt_file_path: filepath to the caffe .prototxt file
    :param caffemodel_file_path: filepath to the caffe .caffemodel file
    :param shape_dict: dictionary of input names to input shapes (eg. (1,224,224,3))
    :param dtype_dict: dictionary of input names to input types (eg. float32 or int64)
    :return: Imported TVM IR module.
    """
    from caffe.proto import caffe_pb2 as pb
    from google.protobuf import text_format
    from afe._tvm._importers._caffe_importer import _caffe_to_tvm_ir

    # NOTE(review): validate_input_parameters is neither defined nor imported
    # in the visible part of this module -- confirm it is in scope here,
    # otherwise this line raises NameError.
    validate_input_parameters(prototxt_file_path, caffemodel_file_path, shape_dict, dtype_dict)

    # Model graph: protobuf text format. The text format is UTF-8, so decode
    # explicitly instead of relying on the platform's locale-dependent default.
    predict_net = pb.NetParameter()
    with open(prototxt_file_path, "r", encoding="utf-8") as f:
        text_format.Merge(f.read(), predict_net)

    # Blob data: binary-serialized NetParameter, read as raw bytes.
    init_net = pb.NetParameter()
    with open(caffemodel_file_path, "rb") as f:
        init_net.ParseFromString(f.read())

    return _caffe_to_tvm_ir(init_net, predict_net, shape_dict, dtype_dict)