#########################################################
# Copyright (C) 2020 SiMa Technologies, Inc.
#
# This material is SiMa proprietary and confidential.
#
# This material may not be copied or distributed without
# the express prior written permission of SiMa.
#
# All rights reserved.
#########################################################
# Code owner: Jeffrey Spitz
#########################################################
import os
from typing import Tuple, Dict
from afe._tvm._defines import TVMIRModule, TVMNDArray
from afe.ir.tensor_type import ScalarType
def get_caffe2_predictor(init_net_file_path: str,
                         predict_net_file_path: str):
    """
    Build a caffe2 predictor that can execute the serialized model.

    Both protobuf files are read verbatim as raw bytes and handed to
    ``caffe2.python.workspace.Predictor`` in (init_net, predict_net) order.

    :param init_net_file_path: filepath to the caffe2 .pb init_net file
    :param predict_net_file_path: filepath to the caffe2 .pb predict_net file
    :return: the ``workspace.Predictor`` constructed from the two files
    """
    # Imported lazily so that caffe2 is only required when this is called.
    from caffe2.python import workspace

    serialized_nets = []
    for net_path in (init_net_file_path, predict_net_file_path):
        # Binary mode: the files are serialized protobufs, not text.
        with open(net_path, 'rb') as net_file:
            serialized_nets.append(net_file.read())

    init_bytes, predict_bytes = serialized_nets
    return workspace.Predictor(init_bytes, predict_bytes)
def import_caffe2_to_tvm(init_net_file_path: str,
                         predict_net_file_path: str,
                         shape_dict: Dict[str, Tuple[int, ...]],
                         dtype_dict: Dict[str, ScalarType]) -> TVMIRModule:
    """
    Import a caffe2 model stored as two .pb files into TVM Relay IR.

    The arguments are first passed to ``validate_input_parameters``. Each file
    is then parsed into a ``caffe2_pb2.NetDef``, and both nets plus the
    shape/dtype dictionaries are forwarded to ``_caffe2_to_tvm_ir``.

    :param init_net_file_path: filepath to the caffe2 .pb init_net file
    :param predict_net_file_path: filepath to the caffe2 .pb predict_net file
    :param shape_dict: dictionary of input names to input shapes (eg. (1,224,224,3))
    :param dtype_dict: dictionary of input names to input types (eg. float32 or int64)
    :return: Imported TVM IR module.
    """
    # Imported lazily so that caffe2 and the importer are only needed on use.
    from caffe2.proto import caffe2_pb2
    from afe._tvm._importers._caffe2_importer import _caffe2_to_tvm_ir

    validate_input_parameters(init_net_file_path, predict_net_file_path, shape_dict, dtype_dict)

    def _read_net_def(pb_path: str):
        # Deserialize one binary protobuf file into a fresh caffe2 NetDef.
        net_def = caffe2_pb2.NetDef()
        with open(pb_path, 'rb') as pb_file:
            net_def.ParseFromString(pb_file.read())
        return net_def

    # Load the graph: the init net first, then the predict net.
    init_net_def = _read_net_def(init_net_file_path)
    predict_net_def = _read_net_def(predict_net_file_path)

    return _caffe2_to_tvm_ir(init_net=init_net_def,
                             predict_net=predict_net_def,
                             shape_dict=shape_dict,
                             dtype_dict=dtype_dict)