#########################################################
# Copyright (C) 2022 SiMa Technologies, Inc.
#
# This material is SiMa proprietary and confidential.
#
# This material may not be copied or distributed without
# the express prior written permission of SiMa.
#
# All rights reserved.
#########################################################
# Code owner: Stefan Jovic
#########################################################
"""
Functions for making MPK JSON node descriptions.
"""
from typing import List, TypeVar, Dict
from afe.backends.mpk.defines import PluginInputNodeMPKData, InOutNodesMPKData
from afe.ir.defines import NodeName, DataValue, TensorValue, TupleValue
from afe.ir.node import AwesomeNode
from afe.ir.tensor_type import data_byte_size, TensorType
# Type variable tying the element type of a DataValue to the element type of
# the list returned by flatten_tuple.
_A = TypeVar("_A")
[docs]
def flatten_tuple(data: DataValue[_A]) -> List[_A]:
    """
    Flatten a single-level tuple into a list of its tensor payloads.

    A TupleValue with 0, 2, or more TensorValue elements yields the list of
    those elements' values. A lone TensorValue yields a one-item list holding
    its value. Other inputs are rejected: a 1-element tuple or a tuple holding
    a non-tensor element raises ValueError, and a non-tuple, non-tensor input
    fails the ``assert`` (note: asserts are stripped when running with ``-O``).

    :param data: Data value to flatten.
    :return: List of the tensor payloads contained in ``data``.
    :raises ValueError: If ``data`` is a 1-element tuple or contains a nested
        (non-tensor) element.
    """
    if not isinstance(data, TupleValue):
        assert isinstance(data, TensorValue)
        return [data.value]

    elements = data.elements
    # A flattened 1-element tuple would be indistinguishable from a bare
    # tensor, so that case is rejected to keep the mapping unambiguous.
    if len(elements) == 1:
        raise ValueError("Tuple of 1 element cannot be flattened")
    if any(not isinstance(element, TensorValue) for element in elements):
        raise ValueError("Nested tuples cannot be flattened")
    return [element.value for element in elements]
[docs]
def get_node_size(node: AwesomeNode) -> int:
    """
    Return the size, in bytes, of a node's output.

    :param node: AwesomeNode whose output type is measured.
    :return: ``data_byte_size`` of ``node.get_type().output``, converted to int.
    """
    return int(data_byte_size(node.get_type().output))
[docs]
def get_output_node_names(node_name: str, output_type: DataValue[TensorType],
                          output_names: Dict[str, str]) -> DataValue[str]:
    """
    Get the output node names of a node.

    A single-tensor output is named ``node_name``. A tuple output has one name
    per element, ``f'{node_name}_{idx}'``. When a name is a key of
    ``output_names`` (i.e. that output is an output of the original model),
    ``'/' + output_names[name]`` is appended so the original model output name
    is kept next to the ModelSDK name.

    Only the top level of ``output_type`` is mirrored: a nested tuple element
    still receives a single indexed name.

    :param node_name: ModelSDK name of the node.
    :param output_type: Output type of the node. Only its TensorValue/TupleValue
        structure and the number of tuple elements are used.
    :param output_names: Mapping from ModelSDK output name to the original
        model output name.
    :return: A TensorValue holding the name for a tensor output, or a
        TupleValue of TensorValues (one per element) for a tuple output.
    """
    def _with_original_name(name: str) -> str:
        # Append the original model output name only for model outputs.
        if name in output_names:
            return name + '/' + output_names[name]
        return name

    if isinstance(output_type, TensorValue):
        return TensorValue(_with_original_name(node_name))

    assert isinstance(output_type, TupleValue)
    return TupleValue([TensorValue(_with_original_name(f'{node_name}_{idx}'))
                       for idx, _ in enumerate(output_type.elements)])