from typing import List, Tuple, Dict, Optional, Callable
from afe._tvm._defines import TVMIRModule, TVMNDArray
from afe.ir.tensor_type import ScalarType, scalar_type_to_dtype
[docs]
def import_pytorch_script_module(file_path: str, input_shapes: List[Tuple[int, ...]]):
    """
    Load a complete PyTorch model from disk and convert it to TorchScript by tracing.

    The model is deserialized onto the CPU, put in evaluation mode and traced with
    one random tensor (``torch.randn``) per entry of ``input_shapes``.

    :param file_path: str. Path to a PyTorch file (.pt) that contains the entire
        model object (``eval()`` is called on whatever ``torch.load`` returns).
    :param input_shapes: List[Tuple[int, ...]]. One shape per model input, passed to
        the model positionally in the given order. eg. [(1, 224, 224, 3)]
    :return: The traced TorchScript module, in evaluation mode.
    """
    # Function-scope imports: torch is only needed when this importer is called.
    import inspect
    import torch

    load_kwargs = {"map_location": torch.device("cpu")}
    # Newer PyTorch releases default ``weights_only`` to True, which refuses to
    # unpickle a whole nn.Module. Request full unpickling explicitly, but only
    # when the installed torch.load knows the parameter, so older releases keep
    # working unchanged.
    # SECURITY: full unpickling can execute arbitrary code embedded in the file;
    # only load model files from trusted sources.
    if "weights_only" in inspect.signature(torch.load).parameters:
        load_kwargs["weights_only"] = False

    model = torch.load(file_path, **load_kwargs)
    model.eval()

    # Tracing records the operations executed for these example inputs; only the
    # shapes are caller-controlled, the values are random.
    example_inputs = tuple(torch.randn(shape) for shape in input_shapes)
    return torch.jit.trace(model, example_inputs).eval()
[docs]
def import_pytorch_to_tvm(pt_file_path: str,
                          input_names: List[str],
                          input_shapes: List[Tuple[int, ...]],
                          input_dtypes: Optional[List[ScalarType]],
                          custom_convert_map: Optional[Dict[str, Callable]] = None
                          ) -> TVMIRModule:
    """
    Import a PyTorch model file into a TVM IR module.

    The arguments are first checked with ``validate_input_parameters``, the model is
    loaded and traced with ``import_pytorch_script_module``, and the traced module is
    handed to ``_pytorch_to_tvm_ir`` together with a per-input description.

    :param pt_file_path: Path to a PyTorch file (.pt) that contains the entire model
    :param input_names: List of input names. eg ['input']
    :param input_shapes: List of input shapes corresponding to input names. eg. [(1, 224, 224, 3)]
    :param input_dtypes: Optional list of ScalarType values corresponding to input names.
                         Each one is converted with ``scalar_type_to_dtype``. When None,
                         only names and shapes are passed to the converter.
    :param custom_convert_map: A custom op conversion map that maps operation names to functions. Whenever an
                               operator with a name found in the custom_convert_map is found in TVM, the
                               function is called with 2 arguments:
                                   inputs = tvm relay expression inputs to operator.
                                   input_types = list of strings indicating the input types to the operator.
                               The function then returns the tvm relay IR expression that is inserted into
                               the model wherever the operation occurs.
    :return: Imported TVM IR module.
    """
    from afe._tvm._importers._pytorch_importer import _pytorch_to_tvm_ir

    validate_input_parameters(pt_file_path, input_names, input_shapes, input_dtypes)
    traced_module = import_pytorch_script_module(pt_file_path, input_shapes)

    if input_dtypes is None:
        # No dtypes supplied: each input is described as (name, shape).
        input_infos = list(zip(input_names, input_shapes))
    else:
        # Dtypes supplied: each input is described as (name, (shape, dtype)).
        dtype_names = [scalar_type_to_dtype(scalar_type) for scalar_type in input_dtypes]
        input_infos = [
            (name, (shape, dtype_name))
            for name, shape, dtype_name in zip(input_names, input_shapes, dtype_names)
        ]

    return _pytorch_to_tvm_ir(traced_module, input_infos, custom_convert_map)