#########################################################
# Copyright (C) 2023 SiMa Technologies, Inc.
#
# This material is SiMa proprietary and confidential.
#
# This material may not be copied or distributed without
# the express prior written permission of SiMa.
#
# All rights reserved.
#########################################################
# Code owner: Christopher Rodrigues
#########################################################
from typing import Union, Tuple, Dict, Optional
import os
import numpy as np
import tvm.rpc.client
import tvm.runtime
import tvm
from tvm.contrib.graph_executor import GraphModule
from afe.backends import BackendIR, Backend
from afe.backends.apu.tvm_apu_compiler import CompiledTVMObjectFile
from afe.ir.defines import NodeName, InputName, TensorValue, TupleValue
from afe.ir.execute import execute_node_quant
from afe.ir.node import AwesomeNode
from afe.ir.sima_ir import SiMaIRTensorTypes
from afe.ir.tensor_type import NodeType
def _execute_module(session: tvm.rpc.client.RPCSession, object_file: CompiledTVMObjectFile, node_type: NodeType,
                    inputs: Dict[InputName, np.ndarray]) \
        -> Union[np.ndarray, Tuple[np.ndarray, ...]]:
    """
    Execute a module using a TVM RPC server.

    The object file is uploaded to the server under its base name, loaded as a
    graph executor module, and run once with ``inputs``. Its outputs are then
    copied back as numpy arrays. The uploaded file is removed from the server
    afterwards, also when loading or running the module raises.

    :param session: TVM RPC session
    :param object_file: Compiled module to execute
    :param node_type: The module's type. Its input names must equal the keys of
        ``inputs``, in the same order. Its output type selects the form of the
        return value.
    :param inputs: Input data to pass to the module, keyed by input name
    :return: The single output array if ``node_type.output`` is a TensorValue,
        otherwise a tuple of all output arrays
    """
    # Validate input names before doing any remote work
    assert list(node_type.inputs.keys()) == list(inputs.keys())
    # Get handle to remote processor
    dev = session.cpu()
    # Marshal the executable module to the server
    remote_filename = os.path.basename(object_file.path)
    session.upload(object_file.path, target=remote_filename)
    try:
        func = session.load_module(remote_filename)
        module = GraphModule(func["default"](dev))
        # Execute the module
        for name, value in inputs.items():
            module.set_input(name, value)
        module.run()
        local_outputs = [module.get_output(i).numpy() for i in range(module.get_num_outputs())]
    finally:
        # Delete the uploaded file even on failure so the server is not left with stale files
        session.remove(remote_filename)
    if isinstance(node_type.output, TensorValue):
        # Single output
        output, = local_outputs
    else:
        assert isinstance(node_type.output, TupleValue)
        output = tuple(local_outputs)
    return output
def is_arm_backend_node(node: AwesomeNode) -> bool:
    """
    Report whether ``node`` holds compiled backend code for ARM (the APU backend).

    :param node: Node to inspect
    :return: True when the node's IR is a BackendIR whose backend is Backend.APU,
        False otherwise
    """
    ir = node.ir
    if not isinstance(ir, BackendIR):
        return False
    return ir.backend == Backend.APU
class ARMBackendRunner:
    """
    Implementation of execution on an ARM processor,
    conforming to the protocol of AwesomeNet.run.

    Nodes holding compiled APU backend code are executed remotely over a TVM
    RPC session. All other nodes are executed locally with execute_node_quant.
    """
    # RPC session used to run APU backend code remotely
    _session: tvm.rpc.client.RPCSession

    def __init__(self, host: str, port: int, session_timeout: float = 240.0):
        """
        Initialize the runner and connect to the RPC server.

        :param host: Address of the TVM RPC server
        :param port: Port of the TVM RPC server
        :param session_timeout: Value passed as ``session_timeout`` to
            tvm.rpc.client.connect. Defaults to 240.0, the value this runner
            has always used.
        """
        self._session = tvm.rpc.client.connect(host, port, session_timeout=session_timeout)

    def _execute_apu_node(self, node: AwesomeNode, inputs: Dict[InputName, SiMaIRTensorTypes],
                          node_outputs: Dict[NodeName, Union[np.ndarray, Tuple[np.ndarray, ...]]]):
        """
        Run an APU backend node on the RPC server and record its result.

        :param node: Node for which is_arm_backend_node is true and whose
            IR graph is a CompiledTVMObjectFile
        :param inputs: Input data for the node, keyed by input name
        :param node_outputs: Dict that receives the result under ``node.name``
        """
        assert is_arm_backend_node(node)
        assert isinstance(node.ir.graph, CompiledTVMObjectFile)
        outputs = _execute_module(self._session, node.ir.graph, node.ir.type, inputs)
        # Store outputs in the dict
        node_outputs[node.name] = outputs

    def execute_node(self, node: AwesomeNode, inputs: Dict[InputName, SiMaIRTensorTypes],
                     node_outputs: Dict[NodeName, Union[np.ndarray, Tuple[np.ndarray, ...]]]):
        """
        Execute one node, dispatching on whether it is APU backend code.

        APU backend nodes run remotely, and their result is stored in
        ``node_outputs[node.name]``. Any other node is handed to
        execute_node_quant together with ``node_outputs``. Presumably that
        helper records the result the same way; verify against its definition.

        :param node: Node to execute
        :param inputs: Input data for the node, keyed by input name
        :param node_outputs: Dict of node results, updated by this call
        """
        if is_arm_backend_node(node):
            self._execute_apu_node(node, inputs, node_outputs)
        else:
            execute_node_quant(node, inputs, node_outputs)