Source code for afe.backends.mpk.node

#########################################################
# Copyright (C) 2022 SiMa Technologies, Inc.
#
# This material is SiMa proprietary and confidential.
#
# This material may not be copied or distributed without
# the express prior written permission of SiMa.
#
# All rights reserved.
#########################################################
# Code owner: Stefan Jovic
#########################################################
"""
Functions for making MPK JSON node descriptions.
"""
from typing import List, TypeVar, Dict

from afe.backends.mpk.defines import PluginInputNodeMPKData, InOutNodesMPKData
from afe.ir.defines import NodeName, DataValue, TensorValue, TupleValue
from afe.ir.node import AwesomeNode
from afe.ir.tensor_type import data_byte_size, TensorType

_A = TypeVar("_A")


def flatten_tuple(data: DataValue[_A]) -> List[_A]:
    """
    Flatten a tensor, or an un-nested tuple of tensors, into a list.

    A TensorValue becomes a one-item list holding its value. A TupleValue of
    0, 2, or more TensorValue elements becomes the list of the elements'
    values, in order.

    :param data: Data value to flatten.
    :return: List of the values held by ``data``.
    :raises ValueError: If ``data`` is a 1-element tuple, or a tuple with any
        element that is not a TensorValue (e.g. a nested tuple).
        Inputs that are neither a TupleValue nor a TensorValue fail an ``assert``.
    """
    if not isinstance(data, TupleValue):
        assert isinstance(data, TensorValue)
        return [data.value]

    items = data.elements
    # A 1-tuple would flatten to the same list as a bare tensor, making the
    # result ambiguous, so it is rejected.
    if len(items) == 1:
        raise ValueError("Tuple of 1 element cannot be flattened")
    if any(not isinstance(item, TensorValue) for item in items):
        raise ValueError("Nested tuples cannot be flattened")
    return [item.value for item in items]
def get_node_size(node: AwesomeNode) -> int:
    """
    Return the size of a node's output in bytes.

    :param node: AwesomeNode whose type's ``output`` is measured.
    :return: ``data_byte_size`` of the node's output type, converted to ``int``.
    """
    return int(data_byte_size(node.get_type().output))
def get_output_node_names(node_name: str, output_type: DataValue[TensorType],
                          output_names: Dict[str, str]) -> DataValue[str]:
    """
    Get the output node names.

    A single-tensor output is named ``node_name``. The i-th top-level element
    of a tuple output is named ``f"{node_name}_{i}"``. If a generated name is a
    key of ``output_names`` (the output is an output of the model), the
    original model name is appended to it after a ``/`` separator.

    :param node_name: Name of the node producing the output.
    :param output_type: Output type of the node; a TensorValue or a TupleValue.
    :param output_names: Mapping from generated output names to original
        model output names.
    :return: A TensorValue of the name for a tensor output, or a TupleValue of
        TensorValue names (one per top-level element) for a tuple output.
    """
    def qualify(name: str) -> str:
        # Model outputs carry their original model name as a suffix.
        return name + '/' + output_names[name] if name in output_names else name

    if isinstance(output_type, TensorValue):
        return TensorValue(qualify(node_name))

    assert isinstance(output_type, TupleValue)
    return TupleValue([TensorValue(qualify(f'{node_name}_{idx}'))
                       for idx, _ in enumerate(output_type.elements)])
def get_plugin_input_nodes(input_nodes: List[InOutNodesMPKData]) -> List[PluginInputNodeMPKData]:
    """
    Convert a list of MPK JSON nodes to a list of inputs for an MLA or EV74 plugin.

    :param input_nodes: MPK JSON objects representing inputs to the AwesomeNode.
    :return: One PluginInputNodeMPKData per input node, in the same order,
        carrying that node's ``name`` and ``size``.
    """
    plugin_inputs = [
        PluginInputNodeMPKData(name=node.name, size=node.size)
        for node in input_nodes
    ]
    return plugin_inputs