afe.load.importers.pytorch

Functions

import_pytorch_script_module(file_path, input_shapes)

Load a PyTorch scripted model from a file path and a list of input shapes.

validate_input_parameters(→ None)

Validates the user-supplied input.

import_pytorch_to_tvm(→ afe._tvm._defines.TVMIRModule)

Module Contents

afe.load.importers.pytorch.import_pytorch_script_module(file_path: str, input_shapes: List[Tuple[int, ...]])

Load a PyTorch scripted model from a file path and a list of input shapes.

Parameters:
  • file_path – str. Path to a PyTorch file (.pt) that contains the entire model

  • input_shapes – List[Tuple[int, ...]]. List of input shapes corresponding to input names, e.g. [(1, 224, 224, 3)]

Returns:

A PyTorch scripted model.

afe.load.importers.pytorch.validate_input_parameters(pt_file_path: str, input_names: List[str], input_shapes: List[Tuple[int, ...]], input_dtypes: List[str] | None) → None

Validates the user-supplied input.

Parameters:
  • pt_file_path – Path to a PyTorch file (.pt) that contains the entire model

  • input_names – List of input names, e.g. ['input']

  • input_shapes – List of input shapes corresponding to input names, e.g. [(1, 224, 224, 3)]

  • input_dtypes – List of input datatypes corresponding to input names, e.g. ['float32']
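The parameters above are parallel lists: entry i of input_shapes and input_dtypes describes the input named input_names[i]. The sketch below illustrates the kind of consistency check this implies. It is not the afe implementation; the helper name check_parallel_inputs and the specific checks (a .pt suffix, matching lengths, positive integer dimensions) are assumptions made for illustration.

```python
from typing import List, Optional, Tuple


def check_parallel_inputs(
    pt_file_path: str,
    input_names: List[str],
    input_shapes: List[Tuple[int, ...]],
    input_dtypes: Optional[List[str]],
) -> None:
    """Hypothetical stand-in, not afe's validate_input_parameters."""
    # Assumed check: the model path names a .pt file.
    if not pt_file_path.endswith(".pt"):
        raise ValueError(f"expected a .pt file, got {pt_file_path!r}")
    # Assumed check: one shape per input name.
    if len(input_names) != len(input_shapes):
        raise ValueError("input_names and input_shapes differ in length")
    # Assumed check: dtypes are optional, but if given, one per input name.
    if input_dtypes is not None and len(input_dtypes) != len(input_names):
        raise ValueError("input_dtypes and input_names differ in length")
    # Assumed check: every dimension is a positive integer.
    for name, shape in zip(input_names, input_shapes):
        if not all(isinstance(d, int) and d > 0 for d in shape):
            raise ValueError(f"invalid shape {shape} for input {name!r}")


# Passes silently for consistent arguments.
check_parallel_inputs("model.pt", ["input"], [(1, 224, 224, 3)], ["float32"])
```

Raising on the first inconsistency keeps the error close to the argument that caused it, which is usually easier to act on than a failure deep inside the import.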

afe.load.importers.pytorch.import_pytorch_to_tvm(pt_file_path: str, input_names: List[str], input_shapes: List[Tuple[int, ...]], input_dtypes: List[afe.ir.tensor_type.ScalarType] | None, custom_convert_map: Dict[str, Callable] | None = None) → afe._tvm._defines.TVMIRModule
Parameters:
  • pt_file_path – Path to a PyTorch file (.pt) that contains the entire model

  • input_names – List of input names, e.g. ['input']

  • input_shapes – List of input shapes corresponding to input names, e.g. [(1, 224, 224, 3)]

  • input_dtypes – List of input datatypes corresponding to input names, e.g. ['float32']

  • custom_convert_map – A custom op conversion map that maps operation names to functions. Whenever an operator whose name appears in custom_convert_map is encountered in TVM, the function is called with two arguments: inputs, the TVM Relay expression inputs to the operator, and input_types, a list of strings indicating the input types to the operator. The function then returns the TVM Relay IR expression that is inserted into the model wherever the operation occurs.

Returns:

Imported TVM IR module.
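
The custom_convert_map calling convention can be illustrated without TVM. In this sketch, plain strings stand in for Relay expressions, and convert_op, DEFAULT_CONVERTERS, clipped_relu, and the operator name are hypothetical. It only shows how a name-to-function map with the (inputs, input_types) signature could take over handling of an operator, on the assumption that custom entries override defaults. It does not show how afe or TVM implement the lookup.

```python
from typing import Callable, Dict, List, Optional

# Strings stand in for TVM Relay expressions in this toy model.
Expr = str
Converter = Callable[[List[Expr], List[str]], Expr]

# Hypothetical built-in conversions, keyed by operator name.
DEFAULT_CONVERTERS: Dict[str, Converter] = {
    "aten::relu": lambda inputs, input_types: f"relu({inputs[0]})",
}


def convert_op(
    op_name: str,
    inputs: List[Expr],
    input_types: List[str],
    custom_convert_map: Optional[Dict[str, Converter]] = None,
) -> Expr:
    """Convert one operator, letting custom entries override defaults."""
    converters = dict(DEFAULT_CONVERTERS)
    converters.update(custom_convert_map or {})
    if op_name not in converters:
        raise KeyError(f"no converter for operator {op_name!r}")
    # The converter receives the operator inputs and their type names,
    # and returns the expression inserted in place of the operator.
    return converters[op_name](inputs, input_types)


def clipped_relu(inputs: List[Expr], input_types: List[str]) -> Expr:
    """Custom converter with the two-argument signature described above."""
    return f"clip({inputs[0]}, 0, 6)"


default = convert_op("aten::relu", ["x"], ["float32"])
custom = convert_op("aten::relu", ["x"], ["float32"], {"aten::relu": clipped_relu})
print(default)  # relu(x)
print(custom)   # clip(x, 0, 6)
```

Copying the defaults before merging keeps the override local to a single call, so one import's custom map cannot leak into the next.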