#########################################################
# Copyright (C) 2022 SiMa Technologies, Inc.
#
# This material is SiMa proprietary and confidential.
#
# This material may not be copied or distributed without
# the express prior written permission of SiMa.
#
# All rights reserved.
#########################################################
# Code owner: Nenad Nikolic
#########################################################
"""
The command line driver code. This code executes the
commands that were constructed by the command line parser.
"""
from typing import Optional, Callable, Tuple
import numpy as np
import os
from afe.apis.compilation_job import CompilationJob
from afe.core.compile_networks import compile_net_to_elf
from afe.driver.passes import import_and_transform
from afe.ir.net import AwesomeNet
# The module-level global ``saved_compilation_job`` (defined below) is used for
# passing data from a user module to the driver.
from afe.ir.tensor_type import ScalarType
from afe.load.importers.general_importer import ImporterParams, ModelFormat, make_model_name_from_path
# Hand-off slot between user code and the driver: the ``process`` callback given to
# ``receive_compilation_job`` stores a CompilationJob here, and ``receive_compilation_job``
# reads it and resets it to None.
saved_compilation_job: Optional[CompilationJob] = None
def receive_compilation_job(process: Callable[[], None]) -> CompilationJob:
    """
    Run ``process`` and return the CompilationJob it stores in the module-level
    ``saved_compilation_job`` global.

    The global is cleared before ``process`` runs and cleared again afterwards, also when
    ``process`` raises. This way a job can never leak from one call into the next.

    :param process: Callable expected to assign ``saved_compilation_job``.
    :return: The CompilationJob that ``process`` stored.
    :raises ValueError: If ``process`` returned without storing a job.
    """
    global saved_compilation_job
    saved_compilation_job = None
    try:
        process()
        job = saved_compilation_job
    finally:
        # Reset unconditionally so a job set by a failing ``process`` does not linger.
        saved_compilation_job = None
    if job is None:
        raise ValueError("Did not receive a CompilationJob")
    return job
def _parse_param_type(s: str) -> ScalarType:
    """
    Convert a numpy dtype name (e.g. ``"float32"``) to a ScalarType.

    Surrounding whitespace is ignored, so a comma-separated list written with spaces
    (``"float32, int8"``) parses the same as one without. This matches
    ``_parse_param_shape``, where ``int()`` already ignores whitespace.

    :param s: A string accepted by ``numpy.dtype`` once stripped.
    :return: The corresponding ScalarType.
    """
    return ScalarType.from_numpy(np.dtype(s.strip()))
def _parse_param_shape(s: str) -> Tuple[int, ...]:
return tuple(int(n) for n in s.split('-'))
def interpret_importer_params(args) -> ImporterParams:
    """
    Make an ImporterParams from the data in command line arguments that describes how to import
    a model.

    Each optional setting is read from ``args`` only when the attribute exists and is not None.
    Otherwise the corresponding ImporterParams field is left as None.

    :param args: Parsed command line arguments. Must have ``input_path``. May have ``framework``,
        ``param_names`` (comma-separated), ``param_dtypes`` (comma-separated dtype names),
        ``param_shapes`` (comma-separated, each dash-separated) and ``layout``.
    :return: ImporterParams describing the model at ``args.input_path``.
    """
    def optional_arg(attr_name, convert):
        # A missing attribute and one holding None both mean "not given". argparse, for one,
        # stores None for an omitted option unless another default is configured. A bare
        # hasattr() check would then pass None on and crash, e.g. on ``None.split(',')``.
        value = getattr(args, attr_name, None)
        return None if value is None else convert(value)

    def split_list(text):
        return text.split(',')

    model_format = optional_arg('framework', ModelFormat)
    input_names = optional_arg('param_names', split_list)
    input_types = optional_arg('param_dtypes',
                               lambda text: [_parse_param_type(t) for t in split_list(text)])
    input_shapes = optional_arg('param_shapes',
                                lambda text: [_parse_param_shape(t) for t in split_list(text)])
    layout = getattr(args, 'layout', None)
    return ImporterParams(model_format, file_paths=[args.input_path], input_names=input_names,
                          input_types=input_types, layout=layout, input_shapes=input_shapes)
def _load_model(args) -> AwesomeNet:
    """
    Import the model at ``args.input_path`` and run the import/transform pipeline on it.

    Intended for the command line driver only.

    :param args: Parsed command line arguments; see ``interpret_importer_params``.
        ``args.quantized`` is also read.
    :return: The AwesomeNet produced by the pipeline.
    """
    params = interpret_importer_params(args)
    model_name = make_model_name_from_path(args.input_path)
    pipeline = import_and_transform(params, name=model_name, is_quantized=args.quantized)
    return pipeline.run()
def compile_command(args):
    """
    Run the compile command.

    Loads the model described by ``args`` and passes it to ``compile_net_to_elf`` with
    ``args.output_path`` and large-tensor support enabled. Prints "Compiled" when that returns.

    :param args: Parsed command line arguments (see ``interpret_importer_params``). Must also
        have ``quantized`` and ``output_path``.
    """
    net = _load_model(args)
    # The return value of compile_net_to_elf is not needed by the CLI.
    compile_net_to_elf(net, args.output_path, enable_large_tensors=True)
    print("Compiled")