Source code for afe.load.importers.pytorch

from typing import List, Tuple, Dict, Optional, Callable
from afe._tvm._defines import TVMIRModule, TVMNDArray
from afe.ir.tensor_type import ScalarType, scalar_type_to_dtype


def import_pytorch_script_module(file_path: str, input_shapes: List[Tuple[int, ...]]):
    """
    Load a PyTorch model from a file and convert it to TorchScript by tracing.

    :param file_path: str. Path to a PyTorch file (.pt) that contains the entire model
        (the whole module object, not only its weights).
    :param input_shapes: List[Tuple[int, ...]]. List of input shapes corresponding to input
        names. eg. [(1, 224, 224, 3)]
    :return: A TorchScript module produced by ``torch.jit.trace`` on random float inputs of
        the given shapes, switched to eval mode.
    """
    import inspect
    import torch

    load_kwargs = {'map_location': torch.device('cpu')}
    # A file holding the entire model is a pickled module object and needs the full
    # unpickler. torch >= 2.6 defaults to weights_only=True, which rejects such files, so
    # request full unpickling explicitly wherever the parameter exists (releases without
    # the parameter always unpickle fully).
    # SECURITY: full unpickling can execute arbitrary code; only load trusted model files.
    if 'weights_only' in inspect.signature(torch.load).parameters:
        load_kwargs['weights_only'] = False
    model = torch.load(file_path, **load_kwargs)
    model.eval()

    # Obtain a TorchScript module by tracing with random float32 example inputs.
    # NOTE(review): models whose inputs must be non-float tensors (e.g. integer indices)
    # may fail to trace with these example inputs.
    example_inputs = tuple(torch.randn(shape) for shape in input_shapes)
    return torch.jit.trace(model, example_inputs).eval()
def validate_input_parameters(pt_file_path: str, input_names: List[str],
                              input_shapes: List[Tuple[int, ...]],
                              input_dtypes: Optional[List[str]]) -> None:
    """
    Validate the user-supplied PyTorch import parameters.

    :param pt_file_path: Path to a PyTorch file (.pt) that contains the entire model.
        Only checked to contain the substring '.pt'.
    :param input_names: List of input names. eg ['input']
    :param input_shapes: List of input shapes corresponding to input names. eg. [(1, 224, 224, 3)]
    :param input_dtypes: Optional list of input datatypes corresponding to input names.
        eg ['float32']. When not None it must have one entry per input name; an empty
        list is therefore rejected for a non-empty ``input_names``.
    :raises AssertionError: If any of the checks fails.
    """
    # Explicit raises instead of ``assert`` statements so validation still runs under
    # ``python -O``; AssertionError is kept so callers that already catch it keep working.
    if '.pt' not in pt_file_path:
        raise AssertionError("Error: We expect a .pt file to be supplied for a PyTorch import")
    if len(input_names) != len(input_shapes):
        raise AssertionError("Error number of input names should equal number of input shapes")
    # Compare against None rather than truthiness: an empty dtype list is "provided" and
    # must match the number of names, otherwise inputs would be dropped downstream.
    if input_dtypes is not None and len(input_names) != len(input_dtypes):
        raise AssertionError("Error number of input names should equal number of input dtypes")
def import_pytorch_to_tvm(pt_file_path: str, input_names: List[str],
                          input_shapes: List[Tuple[int, ...]],
                          input_dtypes: Optional[List[ScalarType]],
                          custom_convert_map: Optional[Dict[str, Callable]] = None
                          ) -> TVMIRModule:
    """
    Import a PyTorch model file into a TVM IR module.

    The parameters are validated, the model is loaded and traced with
    import_pytorch_script_module, and the traced module is converted with the TVM
    PyTorch importer.

    :param pt_file_path: Path to a PyTorch file (.pt) that contains the entire model
    :param input_names: List of input names. eg ['input']
    :param input_shapes: List of input shapes corresponding to input names. eg. [(1, 224, 224, 3)]
    :param input_dtypes: Optional list of input datatypes (ScalarType values) corresponding
        to input names. Each one is converted with scalar_type_to_dtype before being passed
        to TVM. If None or empty, only names and shapes are passed.
    :param custom_convert_map: A custom op conversion map that maps operation names to
        functions. Whenever an operator with a name found in the custom_convert_map is found
        in TVM, the function is called with 2 arguments:
            inputs = tvm relay expression inputs to operator.
            input_types = list of strings indicating the input types to the operator.
        The function then returns the tvm relay IR expression that is inserted into the
        model wherever the operation occurs.
    :return: Imported TVM IR module.
    """
    from afe._tvm._importers._pytorch_importer import _pytorch_to_tvm_ir

    validate_input_parameters(pt_file_path, input_names, input_shapes, input_dtypes)
    script_module = import_pytorch_script_module(pt_file_path, input_shapes)

    # Treat an empty dtype list the same as None: zipping with [] would otherwise yield
    # no input infos at all and silently drop every model input.
    if input_dtypes:
        dtype_strs = [scalar_type_to_dtype(dtype) for dtype in input_dtypes]
        input_infos = [(name, (shape, dtype))
                       for name, shape, dtype in zip(input_names, input_shapes, dtype_strs)]
    else:
        input_infos = list(zip(input_names, input_shapes))
    return _pytorch_to_tvm_ir(script_module, input_infos, custom_convert_map)