Source code for afe.driver.cli.commands

#########################################################
# Copyright (C) 2022 SiMa Technologies, Inc.
#
# This material is SiMa proprietary and confidential.
#
# This material may not be copied or distributed without
# the express prior written permission of SiMa.
#
# All rights reserved.
#########################################################
# Code owner: Nenad Nikolic
#########################################################
"""
The command line driver code.  This code executes the
commands that were constructed by the command line parser.
"""
from typing import Optional, Callable, Tuple

import numpy as np
import os

from afe.apis.compilation_job import CompilationJob
from afe.core.compile_networks import compile_net_to_elf
from afe.driver.passes import import_and_transform
from afe.ir.net import AwesomeNet


# This global variable is used for passing data from
# a user module to the driver.
from afe.ir.tensor_type import ScalarType
from afe.load.importers.general_importer import ImporterParams, ModelFormat, make_model_name_from_path

# Hand-off slot used for passing data from a user module to the driver:
# the user module assigns its CompilationJob here and the driver collects
# it through receive_compilation_job().
saved_compilation_job: Optional[CompilationJob] = None


def receive_compilation_job(process: Callable[[], None]) -> CompilationJob:
    """
    Run the given code to set the global compilation job and
    then read the value that was set.

    :param process: Zero-argument callable that is expected to assign
        ``saved_compilation_job`` as a side effect.
    :return: The value that ``process`` stored in ``saved_compilation_job``.
    :raises ValueError: If ``process`` returned without storing a job.
    """
    global saved_compilation_job

    # Start from an empty slot so a value left over from earlier code
    # can never be mistaken for the result of this call.
    saved_compilation_job = None
    try:
        process()
        job = saved_compilation_job
    finally:
        # Clear the slot on every exit path, including when process()
        # raises after having stored a job, so no stale value lingers.
        saved_compilation_job = None

    if job is None:
        raise ValueError("Did not receive a CompilationJob")
    return job
def _parse_param_type(s: str) -> ScalarType: return ScalarType.from_numpy(np.dtype(s)) def _parse_param_shape(s: str) -> Tuple[int, ...]: return tuple(int(n) for n in s.split('-'))
def interpret_importer_params(args) -> ImporterParams:
    """
    Make an ImporterParams from the data in command line arguments
    that describes how to import a model.

    The attributes ``framework``, ``param_names``, ``param_dtypes``,
    ``param_shapes`` and ``layout`` are optional.  An attribute that is
    missing from ``args`` or whose value is None is passed on as None.
    ``args.input_path`` is required and becomes the only entry of
    ``file_paths``.

    :param args: Parsed command line arguments
    :return: Importer parameters built from ``args``.
    """
    # getattr with a None default covers both a missing attribute and an
    # attribute that is present but None; either one means "not specified".
    framework = getattr(args, 'framework', None)
    model_format = ModelFormat(framework) if framework is not None else None

    # Names, dtypes and shapes are comma-separated lists, one entry per input.
    names_text = getattr(args, 'param_names', None)
    input_names = names_text.split(',') if names_text is not None else None

    dtypes_text = getattr(args, 'param_dtypes', None)
    input_types = [_parse_param_type(t) for t in dtypes_text.split(',')] \
        if dtypes_text is not None else None

    shapes_text = getattr(args, 'param_shapes', None)
    input_shapes = [_parse_param_shape(t) for t in shapes_text.split(',')] \
        if shapes_text is not None else None

    layout = getattr(args, 'layout', None)

    return ImporterParams(model_format,
                          file_paths=[args.input_path],
                          input_names=input_names,
                          input_types=input_types,
                          layout=layout,
                          input_shapes=input_shapes)
def _load_model(args) -> AwesomeNet:
    """
    Private loading function for CLI purposes only.

    Builds importer parameters from ``args``, derives the model name from
    ``args.input_path``, and runs the pipeline returned by
    ``import_and_transform``.

    :param args: Parsed command line arguments; ``input_path`` and
        ``quantized`` are read here.
    :return: Result of running the import-and-transform pipeline.
    """
    params = interpret_importer_params(args)
    model_name = make_model_name_from_path(args.input_path)
    pipeline = import_and_transform(params,
                                    name=model_name,
                                    is_quantized=args.quantized)
    return pipeline.run()
def compile_command(args):
    """
    Execute the compile command.

    Loads the model described by ``args``, passes it to
    ``compile_net_to_elf`` together with ``args.output_path`` and with
    large tensors enabled, then prints a confirmation message.

    :param args: Parsed command line arguments
    """
    model = _load_model(args)
    # The return value of the compiler call is intentionally not used.
    compile_net_to_elf(model, args.output_path, enable_large_tensors=True)
    print("Compiled")