"""Source code for afe.load.importers.tensorflow"""

from typing import List, Tuple, Dict, Optional, Callable
from afe._tvm._defines import TVMIRModule


def _process_graph_def_param(graph_def):
    """
    Type-check and possibly canonicalize `graph_def`.

    Taken from: https://github.com/apache/tvm/blob/main/python/tvm/relay/testing/tf.py#L45

    Parameters
    ----------
    graph_def : Obj
        tensorflow graph definition.

    Returns
    -------
    graph_def : Obj
        tensorflow graph definition

    """
    from tensorflow.core.framework import graph_pb2

    # Already a proper GraphDef proto: nothing to do.
    if isinstance(graph_def, graph_pb2.GraphDef):
        return graph_def

    # `graph_def` could be a dynamically-created message, so try a duck-typed
    # approach: copy its fields into a freshly built GraphDef.
    try:
        canonical = graph_pb2.GraphDef()
        canonical.MergeFrom(graph_def)
    except TypeError:
        raise TypeError("graph_def must be a GraphDef proto.")
    return canonical


def convert_to_list(x):
    """Taken from: https://github.com/apache/tvm/blob/main/python/tvm/relay/testing/tf.py#L72"""
    # Wrap any non-list value in a one-element list; lists pass through unchanged.
    return x if isinstance(x, list) else [x]
def _add_shapes_to_graph_def(session, out_node):
    """
    Taken from: https://github.com/apache/tvm/blob/main/python/tvm/relay/testing/tf.py#L118

    Add shapes attribute to nodes of the graph.
    Input graph here is the default graph in context.

    Parameters
    ----------
    session : tf.Session
        Tensorflow session
    out_node : String or List
        Final output node of the graph.

    Returns
    -------
    graph_def : Obj
        tensorflow graph definition with shapes attribute added to nodes.

    """
    import tensorflow as tf

    # Prefer the TF2 compat shim; fall back to the top-level namespace on
    # TF 1.x installs that lack `tf.compat.v1`.
    try:
        tf_compat_v1 = tf.compat.v1
    except (ImportError, AttributeError):
        tf_compat_v1 = tf

    # Freeze variables into constants so the resulting GraphDef is
    # self-contained, keeping the `_output_shapes` attributes.
    return tf_compat_v1.graph_util.convert_variables_to_constants(
        session,
        session.graph.as_graph_def(add_shapes=True),
        convert_to_list(out_node),
    )
def import_tensorflow_model(file_path: str, output_names: List[str]):
    """
    Load a Tensorflow model from a pb filepath and a list of output names

    :param file_path: str. File path to the tensorflow .pb file
    :param output_names: List[str]. List of output tensor names (really only
        applies to tensorflow imports.)
    :return: A Tensorflow graph
    """
    import tensorflow as tf

    # Deserialize the frozen GraphDef from disk.
    with tf.compat.v1.gfile.GFile(file_path, 'rb') as f:
        graph_def = tf.compat.v1.GraphDef()
        graph_def.ParseFromString(f.read())

    # Import the graph definition into the default graph.
    tf.import_graph_def(graph_def, name='')
    # Call the utility to type-check / canonicalize the graph definition.
    _process_graph_def_param(graph_def)

    # Add the add_shapes attribute to the graph at the out_node.
    with tf.compat.v1.Session() as sess:
        graph_def = _add_shapes_to_graph_def(sess, output_names)
    del tf
    return graph_def
def import_tensorflow2_model(file_path: str):
    """
    Load a Tensorflow model from a pb filepath and a list of output names

    :param file_path: str. File path to the tensorflow SavedModel dir.
    :return: A Tensorflow graph
    """
    import tensorflow as tf
    from tensorflow.python.framework.convert_to_constants import \
        convert_variables_to_constants_v2

    # Load the SavedModel and grab its default serving signature.
    tf2_model = tf.saved_model.load(file_path)
    func = tf2_model.signatures['serving_default']

    # Freeze variables into constants, then export a GraphDef that carries
    # shape attributes on each node.
    frozen_func = convert_variables_to_constants_v2(func)
    return frozen_func.graph.as_graph_def(add_shapes=True)
def validate_input_parameters(pb_file_path: str, shape_dict: Dict[str, Tuple[int, ...]],
                              output_names: List[str], layout: str) -> None:
    """
    Validates the user supplied input. Given a .pb file we should be able to discern
    the shape_dict and dtype_dict automatically and we can validate the user supplied
    output_names. The layout of the network must be given by the user. If the user
    has supplied their own input, this function will validate the input. If data
    cannot be discerned and the user has not supplied input this function will
    raise errors.

    :param pb_file_path: filepath to the tensorflow .pb file
    :param shape_dict: dictionary of input names to input shapes (eg. (1,224,224,3))
    :param output_names: list of output tensor names (really only applies to
        tensorflow imports.)
    :param layout: any variation of the characters NHWC representing Batch Size,
        Height, Width, and Channels
    """
    import tensorflow as tf
    assert '.pb' in pb_file_path, "Error: We expect a .pb file to be supplied for a TensorFlow import"
    assert len(output_names) > 0, "Error: User needs to supply output names"
    assert layout, "Error: User needs to supply a layout"

    # Create an interactive session
    tf.compat.v1.reset_default_graph()
    sess = tf.compat.v1.InteractiveSession()
    # FIX: the session was previously leaked whenever a validation check below
    # raised; close it in a finally block so it is always released.
    try:
        # Load the graph
        with tf.compat.v1.gfile.GFile(pb_file_path, "rb") as f:
            graph_def = tf.compat.v1.GraphDef()
            graph_def.ParseFromString(f.read())
            tf.import_graph_def(graph_def, name="")
        graph = tf.compat.v1.get_default_graph()
        graph_ops = graph.get_operations()

        # Gather detected TF node names, input shapes.  Placeholder ops are the
        # graph inputs; a Placeholder with an unknown shape maps to None.
        tf_shape_dict: Dict[str, Optional[Tuple[int, ...]]] = dict()
        node_names: List[str] = []
        for op in graph_ops:
            node_name = op.name
            if op.type == 'Placeholder':
                input_shape = op.outputs[0].shape
                tf_shape_dict[node_name] = tuple(input_shape) if input_shape else None
            node_names.append(node_name)

        # Check that the user supplied output names are valid
        for name in output_names:
            if name not in node_names:
                raise ValueError("Output name ({}) not present in network.\nValid Node names include:\n{}"
                                 .format(name, "".join([n + "\n" for n in node_names])))

        if shape_dict:
            # Check that all user provided keys are in the detected set of keys.
            assert all([key in tf_shape_dict for key in shape_dict.keys()]), \
                "Error: Some user provided Inputs cannot be found in .pb files. User Inputs: ({}), TF Inputs ({}).". \
                format(shape_dict.keys(), tf_shape_dict.keys())
            # Validate user data with tf data. If tf data is missing ensure user data covers it.
            for key, tf_shape in tf_shape_dict.items():
                assert key in shape_dict, "User must supply shape data for input ({}).".format(key)
                if tf_shape is None:
                    # Shape could not be discerned from the graph; trust the user's.
                    continue
                user_shape = shape_dict[key]
                assert len(tf_shape) == len(user_shape), \
                    "Detected shape for input ({}) differs from user's: {} != {}". \
                    format(key, len(tf_shape), len(user_shape))
                # Dimensions that TF reports as None (dynamic) are not compared.
                for s_tf, s_user in zip(tf_shape, user_shape):
                    if s_tf is not None:
                        assert s_tf == s_user, \
                            "Non-None detected shape values for input ({}) differ from user shape: {} != {}". \
                            format(key, tf_shape, user_shape)
    finally:
        sess.close()
    del tf
def import_tensorflow_pb_to_tvm(pb_file_path: str, shape_dict: Dict[str, Tuple[int, ...]],
                                output_names: List[str], layout: str,
                                custom_convert_map: Optional[Dict[str, Callable]] = None
                                ) -> TVMIRModule:
    """
    Use TVM frontend to import a tensorflow pb model into TVM Relay IR

    :param pb_file_path: filepath to the tensorflow .pb file
    :param shape_dict: dictionary of input names to input shapes (eg. (1,224,224,3))
    :param output_names: list of output tensor names (really only applies to
        tensorflow imports.)
    :param layout: any variation of the characters NHWC representing Batch Size,
        Height, Wdith, and Channels
    :param custom_convert_map: A custom op conversion map that maps operation names
        to functions. Whenever an operator with a name found in the
        custom_convert_map is found in TVM, the function is called with 4 arguments:
            inputs = tvm relay expression inputs to operator.
            attr = list of strings to operation attributes.
            params = list of strings to tvm runtime arrays that are constants in
                the network.
            mod = The tvm irmodule containing subgraphs it uses to help construct
                the main graph from tensorflow.
        The function then returns the tvm relay IR expression that is inserted into
        the model wherever the operation occurs.
    :return: TVM IR module.
    """
    from afe._tvm._importers._tensorflow_importer import _tensorflow_to_tvm_ir

    # Sanity-check the user-supplied arguments before spending time on import.
    validate_input_parameters(pb_file_path, shape_dict, output_names, layout)
    graph_def = import_tensorflow_model(pb_file_path, output_names)
    # TVM expects tensor names ("node:0"), not bare node names.
    outputs = [name + ":0" for name in output_names]
    return _tensorflow_to_tvm_ir(graph_def, layout, shape_dict, outputs, custom_convert_map)
def import_tensorflow2_pb_to_tvm(saved_model_path: str, shape_dict: Dict[str, Tuple[int, ...]],
                                 output_names: List[str], layout: str,
                                 custom_convert_map: Optional[Dict[str, Callable]] = None) -> TVMIRModule:
    """
    Use TVM frontend to import a tensorflow pb model into TVM Relay IR

    :param saved_model_path: path to the tensorflow SavedModel directory.
    :param shape_dict: dictionary of input names to input shapes (eg. (1,224,224,3))
    :param output_names: list of output tensor names (really only applies to
        tensorflow imports.)
    :param layout: any variation of the characters NHWC representing Batch Size,
        Height, Wdith, and Channels
    :param custom_convert_map: A custom op conversion map that maps operation names
        to functions. Whenever an operator with a name found in the
        custom_convert_map is found in TVM, the function is called with 4 arguments:
            inputs = tvm relay expression inputs to operator.
            attr = list of strings to operation attributes.
            params = list of strings to tvm runtime arrays that are constants in
                the network.
            mod = The tvm irmodule containing subgraphs it uses to help construct
                the main graph from tensorflow.
        The function then returns the tvm relay IR expression that is inserted into
        the model wherever the operation occurs.
    :return: TVM IR module.
    """
    from afe._tvm._importers._tensorflow_importer import _tensorflow2_to_tvm_ir
    graph_def = import_tensorflow2_model(saved_model_path)
    # Frozen TF2 functions name their outputs Identity, Identity_1, Identity_2, ...
    # FIX: the previous code did `outputs.append(<generator expression>)`, which
    # appended the generator object itself (not the names), and numbered the
    # extra outputs from 0, colliding with the first "Identity" output.
    outputs = ["Identity:0"]
    outputs.extend("Identity_" + str(i) + ":0" for i in range(1, len(output_names)))
    return _tensorflow2_to_tvm_ir(graph_def, layout, shape_dict, outputs, custom_convert_map)