Source code for afe.load.importers.caffe2

#########################################################
# Copyright (C) 2020 SiMa Technologies, Inc.
#
# This material is SiMa proprietary and confidential.
#
# This material may not be copied or distributed without
# the express prior written permission of SiMa.
#
# All rights reserved.
#########################################################
# Code owner: Jeffrey Spitz
#########################################################
import os
from typing import Tuple, Dict
from afe._tvm._defines import TVMIRModule, TVMNDArray
from afe.ir.tensor_type import ScalarType


def get_caffe2_predictor(init_net_file_path: str, predict_net_file_path: str):
    """
    Build a caffe2 predictor from a serialized init net and predict net.

    Both files are read in binary mode and their raw bytes are handed to
    ``caffe2.python.workspace.Predictor``.

    :param init_net_file_path: filepath to the caffe2 .pb init_net file
    :param predict_net_file_path: filepath to the caffe2 .pb predict_net file
    :return: the ``workspace.Predictor`` built from the two serialized nets
    """
    # caffe2 is imported lazily so this module can be loaded without it installed.
    from caffe2.python import workspace

    def _read_bytes(pb_path: str) -> bytes:
        """Return the full binary content of the file at pb_path."""
        with open(pb_path, 'rb') as pb_file:
            return pb_file.read()

    serialized_init_net = _read_bytes(init_net_file_path)
    serialized_predict_net = _read_bytes(predict_net_file_path)
    return workspace.Predictor(serialized_init_net, serialized_predict_net)


def validate_input_parameters(init_net_file_path: str, predict_net_file_path: str,
                              shape_dict: Dict[str, Tuple[int, ...]],
                              dtype_dict: Dict[str, ScalarType]) -> None:
    """
    Validates the user supplied input.

    Checks, in order, that both model paths end with the ``.pb`` extension, that
    both paths exist on disk, and that both dictionaries are non-empty.

    :param init_net_file_path: filepath to the caffe2 .pb init_net file
    :param predict_net_file_path: filepath to the caffe2 .pb predict_net file
    :param shape_dict: dictionary of input names to input shapes (eg. (1,224,224,3))
    :param dtype_dict: dictionary of input names to input types (eg. float32 or int64)
    :raises AssertionError: if any of the checks fails. The exception is raised
        explicitly (not through ``assert`` statements) so that validation still
        runs when Python is started with ``-O``.
    """
    def _require(condition, message: str) -> None:
        """Raise AssertionError(message) unless condition is truthy."""
        if not condition:
            raise AssertionError(message)

    # Normalise to plain strings so os.PathLike arguments are accepted as well.
    init_path = os.fspath(init_net_file_path)
    predict_path = os.fspath(predict_net_file_path)

    # Test the suffix rather than substring membership: '.pb' in path would also
    # accept e.g. 'net.pbtxt' or a '.pb' that only appears in a directory name.
    _require(init_path.endswith('.pb'),
             "Error: We expect a .pb file to be supplied for a caffe2 init_net import")
    _require(predict_path.endswith('.pb'),
             "Error: We expect a .pb file to be supplied for a caffe2 predict_net import")
    _require(os.path.exists(init_path),
             f"caffe2 .pb init_net file cannot be found at {init_net_file_path}")
    _require(os.path.exists(predict_path),
             f"caffe2 .pb predict_net file cannot be found at {predict_net_file_path}")
    _require(shape_dict,
             "Error: Please supply a shape dictionary in the form of input names to input shapes")
    _require(dtype_dict,
             "Error: Please supply an dtype dictionary in the form of input names to input types")


def import_caffe2_to_tvm(init_net_file_path: str, predict_net_file_path: str,
                         shape_dict: Dict[str, Tuple[int, ...]],
                         dtype_dict: Dict[str, ScalarType]) -> TVMIRModule:
    """
    Use TVM frontend to import a caffe2 pb model into TVM Relay IR.

    The arguments are validated first; then both .pb files are parsed into
    ``caffe2_pb2.NetDef`` messages and passed to ``_caffe2_to_tvm_ir``.

    :param init_net_file_path: filepath to the caffe2 .pb init_net file
    :param predict_net_file_path: filepath to the caffe2 .pb predict_net file
    :param shape_dict: dictionary of input names to input shapes (eg. (1,224,224,3))
    :param dtype_dict: dictionary of input names to input types (eg. float32 or int64)
    :return: Imported TVM IR module.
    """
    # Imported lazily so this module can be loaded without caffe2 installed.
    from caffe2.proto import caffe2_pb2
    from afe._tvm._importers._caffe2_importer import _caffe2_to_tvm_ir

    def _parse_net_def(pb_path: str):
        """Deserialize the NetDef protobuf stored in the file at pb_path."""
        net_def = caffe2_pb2.NetDef()
        with open(pb_path, 'rb') as pb_file:
            serialized = pb_file.read()
        net_def.ParseFromString(serialized)
        return net_def

    # Reject bad arguments before any model file is opened.
    validate_input_parameters(init_net_file_path, predict_net_file_path,
                              shape_dict, dtype_dict)

    # Load the graph: init net first, then predict net.
    parsed_init_net = _parse_net_def(init_net_file_path)
    parsed_predict_net = _parse_net_def(predict_net_file_path)

    return _caffe2_to_tvm_ir(init_net=parsed_init_net,
                             predict_net=parsed_predict_net,
                             shape_dict=shape_dict,
                             dtype_dict=dtype_dict)