# Source code for afe.apis.loaded_net

#########################################################
# Copyright (C) 2022 SiMa Technologies, Inc.
#
# This material is SiMa proprietary and confidential.
#
# This material may not be copied or distributed without
# the express prior written permission of SiMa.
#
# All rights reserved.
#########################################################
# Code owner: Ljubomir Papuga
#########################################################
import copy
import logging
import os.path
import tempfile

import numpy as np
from typing import Iterable, Callable, TypeVar
import dataclasses

import afe._tvm._defines as _tvm_def
from afe._tvm._tvm_graph_partition import CompileMode
from afe._tvm._utils import create_ir_evaluator, run_ir_evaluator
import afe.apis.defines
from afe.apis.defines import (
    InputValues, QuantizationParams, ExceptionFuncType, gen1_target, gen2_target, MinMaxMethod, HistogramMSEMethod,
    MovingAverageMinMaxMethod, HistogramPercentileMethod, HistogramEntropyMethod, quantization_scheme
)
from afe.apis._sanitize_errors import sanitize_exceptions as _sanitize_exceptions
from afe.apis._sanitize_errors import sanitize_afe_error as _sanitize_afe_error
from afe.apis.model import Model
from afe.core.configs import AfeProcessingConfigs, ModelConfigs, TransformerConfigs, OptimizationConfigs, \
    api_calibration_configs, QuantizationPrecision, create_quantization_configs
from afe.core.configs import ConvertLayoutMethod as _ConvertLayoutMethod
from afe.core.graph_manager import transform_irmod_to_awesomenet as _transform_irmod_to_awesomenet
from afe.driver.passes import calibration_quantization as _calibration_quantization
from afe.driver.passes import noise_based_mixed_precision_quantization as _noise_based_mixed_precision_quantization
from afe.driver.passes import MIXED_PRECISION_SEARCH_LIMIT as _MIXED_PRECISION_SEARCH_LIMIT
from afe.driver.passes import import_model as _import_model
from afe.driver.statistic import Statistic

from afe.ir.defines import Status, RequantizationMode
from afe.ir.tensor_type import ScalarType, scalar_type_from_dtype
from afe.ir.utils import transpose_tensor_according_to_layout_strings
from afe.load.importers.general_importer import ImporterParams, detect_format, ModelFormat, onnx_source, \
    default_layout, pytorch_source, keras_source, tensorflow_source, tensorflow2_source, tflite_source
from sima_utils.logging import sima_logger
from sima_utils.common import Platform


# Type of the ground-truth value paired with each evaluation sample in
# ``LoadedNet.quantize_with_accuracy_feedback`` (``evaluation_data`` yields
# ``(InputValues, GroundTruth)`` tuples, and ``accuracy_score`` consumes it).
GroundTruth = TypeVar('GroundTruth')
def _get_integer_activation_quantization_scheme(
        s: afe.apis.defines.QuantizationScheme
) -> afe.apis.defines.QuantizationScheme:
    """
    Get the integer quantization scheme to use for activations.

    If s already designates integer quantization, return it.  Otherwise (s is bf16),
    return the fallback scheme used when bf16 is not supported: asymmetric,
    per-tensor, 16 bits.
    """
    if s.bf16:
        return quantization_scheme(asymmetric=True, per_channel=False, bits=16)
    else:
        return s


def _get_integer_weight_quantization_scheme(
        s: afe.apis.defines.QuantizationScheme
) -> afe.apis.defines.QuantizationScheme:
    """
    Get the integer quantization scheme to use for weights.

    If s already designates integer quantization, return it.  Otherwise (s is bf16),
    return the fallback scheme used when bf16 is not supported: symmetric,
    per-channel, 8 bits.
    """
    if s.bf16:
        return quantization_scheme(asymmetric=False, per_channel=True, bits=8)
    else:
        return s


def _update_optimization_configs_with_quant_config(quantization_config: QuantizationParams) -> OptimizationConfigs:
    """
    Updates optimization configuration with given quantization configuration.

    Args:
        quantization_config: Quantization configuration.

    Returns:
        OptimizationConfigs built from the calibration method and the quantization
        precision / asymmetry / per-channel settings derived from the activation
        and weight quantization schemes.
    """
    activation_scheme = quantization_config.activation_quantization_scheme
    weight_scheme = quantization_config.weight_quantization_scheme

    # Get floating-point and integer parameters from the quantization schemes
    bfloat16 = activation_scheme.bf16
    weights_bfloat16 = weight_scheme.bf16
    weights_bits = weight_scheme.bits
    # Asymmetry and per-channel settings always come from an integer scheme; for bf16
    # schemes the fallback integer scheme supplies them.
    integer_activation_q_scheme = _get_integer_activation_quantization_scheme(activation_scheme)
    integer_weight_q_scheme = _get_integer_weight_quantization_scheme(weight_scheme)
    asymmetric = integer_activation_q_scheme.asymmetric
    per_channel = integer_weight_q_scheme.per_channel

    if bfloat16:
        if weights_bfloat16:
            quantization_precision = QuantizationPrecision.BFLOAT_16
        elif weights_bits == 8:
            quantization_precision = QuantizationPrecision.BFLOAT_16_INT8_WEIGHTS
        else:
            # _validate_quantization_configs rejects integer weights other than 4 or 8 bits
            # when activations are bf16.
            assert weights_bits == 4
            quantization_precision = QuantizationPrecision.BFLOAT_16_INT4_WEIGHTS
    elif integer_activation_q_scheme.bits == 16:
        quantization_precision = QuantizationPrecision.INT_16
    else:
        quantization_precision = QuantizationPrecision.INT_8

    quantization_configs = create_quantization_configs(
        asymmetry=asymmetric,
        per_channel=per_channel,
        quantization_precision=quantization_precision,
        requantization_mode=quantization_config.requantization_mode,
        biascorr_type=quantization_config.biascorr_type,
        channel_equalization=quantization_config.channel_equalization,
        smooth_quant=quantization_config.smooth_quant,
    )
    calibration_configs = api_calibration_configs(calibration_method=quantization_config.calibration_method)
    return OptimizationConfigs(calibration_configs=calibration_configs, quantization_configs=quantization_configs)


def _create_afe_processing_configs(quantization_config: QuantizationParams, is_quantized: bool = False,
                                   layout: str = "NCHW", name: str = "", *,
                                   enabled_backends: CompileMode,
                                   requantization_mode: RequantizationMode = RequantizationMode.sima,
                                   automatic_layout_conversion: bool = False,
                                   target: Platform) -> AfeProcessingConfigs:
    """
    Construct the ``AfeProcessingConfigs`` used to transform and quantize a model.

    Combines model settings (name, layout, whether the model is already quantized),
    transformer settings (enabled backends, requantization mode, layout-conversion
    method) and optimization settings derived from ``quantization_config``.

    Args:
        quantization_config (QuantizationParams): Parameters defining how quantization
            should be applied to the model.
        is_quantized (bool, optional): Whether the model is already quantized.
            Defaults to ``False``.
        layout (str, optional): Model tensor layout, e.g. ``"NCHW"`` or ``"NHWC"``.
            Defaults to ``"NCHW"``.
        name (str, optional): Name assigned to the model. Defaults to an empty string.
            Note that callers in this module forward ``model_name``, which may be ``None``.
        enabled_backends (CompileMode): Backend(s) the model should be partitioned for.
        requantization_mode (RequantizationMode, optional): Mode used for requantization.
            Defaults to ``RequantizationMode.sima``.
        automatic_layout_conversion (bool, optional): If ``True``, use the automated
            layout-conversion method; otherwise use the legacy method.
            Defaults to ``False``.
        target (Platform): Target hardware platform.

    Returns:
        AfeProcessingConfigs: The combined model, transformer and optimization
        configuration for ``target``.
    """
    convert_layout_method = _ConvertLayoutMethod.AUTOMATED if automatic_layout_conversion \
        else _ConvertLayoutMethod.LEGACY
    # Input names/shapes/dtypes are left empty here; only name, layout and
    # quantization status are provided by this helper.
    model_configs = ModelConfigs(name=name, framework="", input_names=[], input_shapes=[], input_dtypes=[],
                                 layout=layout, is_quantized=is_quantized)
    transformer_configs = TransformerConfigs(enabled_backends=enabled_backends,
                                             requantization_mode=requantization_mode,
                                             convert_layout_method=convert_layout_method)
    optimization_configs = _update_optimization_configs_with_quant_config(quantization_config)
    return AfeProcessingConfigs(model_configs, transformer_configs, optimization_configs, target=target)


def _validate_quantization_configs(quantization_config: QuantizationParams, target: Platform) -> None:
    """
    Helper function for validation of quantization parameters.

    All problems found are collected, logged together, and reported in a single
    exception.

    Args:
        quantization_config: Quantization parameters that are to be validated.
        target: Hardware platform to be validated against.

    Returns:
        None.

    Raises:
        UserFacingException: If any quantization parameter is illegal.
    """
    errors: list[str] = []
    calibration_method = quantization_config.calibration_method

    # A falsy calibration method is accepted here.
    if calibration_method and not isinstance(calibration_method,
                                             (MinMaxMethod, HistogramMSEMethod, MovingAverageMinMaxMethod,
                                              HistogramPercentileMethod, HistogramEntropyMethod)):
        errors.append(f"Unsupported CalibrationMethod {calibration_method}.\n")

    activation_quantization_scheme = quantization_config.activation_quantization_scheme
    weight_quantization_scheme = quantization_config.weight_quantization_scheme

    if weight_quantization_scheme.bf16 and not activation_quantization_scheme.bf16:
        errors.append("Cannot use floating-point weights tensors with integer activation tensors.\n")
    if (weight_quantization_scheme.bf16 or activation_quantization_scheme.bf16) and target != gen2_target:
        errors.append("BFLOAT16 is only supported on hardware platform gen2.\n")

    if not activation_quantization_scheme.bf16:
        # Check integer-related quantization scheme fields
        if activation_quantization_scheme.bits not in (8, 16):
            errors.append("Activation tensors can only be quantized using 8 or 16 bits of precision.\n")
        if activation_quantization_scheme.per_channel is not False:
            errors.append("For activations, per-channel quantization scheme is not supported.\n")

    if not weight_quantization_scheme.bf16:
        # Check integer-related quantization scheme fields
        if not activation_quantization_scheme.bf16 and weight_quantization_scheme.bits != 8:
            errors.append("If integer activations is used, weights tensors can only be "
                          "quantized using 8 bits of precision.\n")
        if activation_quantization_scheme.bf16 and weight_quantization_scheme.bits not in (4, 8):
            errors.append("If BFLOAT16 activations is used, weights tensors can only be "
                          "quantized using 4 or 8 bits of precision.\n")
        if weight_quantization_scheme.asymmetric is not False:
            errors.append("For weights, asymmetric quantization scheme is not supported.\n")

    if isinstance(calibration_method, HistogramPercentileMethod) and \
            calibration_method.percentile_value is None:
        errors.append("Percentile calibration method chosen, but percentile value not set. Please set "
                      "percentile value in QuantizationParams. \n")
    if isinstance(calibration_method, (HistogramMSEMethod, HistogramEntropyMethod, HistogramPercentileMethod)) and \
            calibration_method.num_bins is None:
        errors.append("Histogram based calibration method chosen, but the number of bins is not set. "
                      f"Please set the number of bins in {calibration_method}. \n")

    if errors:
        config_error = "".join(errors)
        sima_logger.sima_log_error(config_error)
        raise sima_logger.UserFacingException(config_error)
class LoadedNet:
    """
    An imported network, held as a TVM IR module, prior to SiMa quantization.

    Instances are created by ``load_model``. A ``LoadedNet`` can be executed in
    software (``execute``), calibrated and quantized (``quantize``,
    ``quantize_with_accuracy_feedback``), or, for a model that is already quantized,
    converted without calibration (``convert_to_sima_quantization``).
    """
    _ir_mod: _tvm_def.TVMIRModule
    # Layout string of the model, e.g. "NCHW", "NHWC" or "NCDHW".
    _layout: str
    _target: Platform
    # Evaluator for _ir_mod, created on the first call to execute() and then reused.
    _mod_evaluator: Callable = None
    # Names of model's outputs, if available.  These names are not stored by TVM.
    # If there are multiple names, the names pertain to individual tuple elements.
    _output_labels: list[str] | None
    # Path to the original model
    _model_path: str | None

    def __init__(self, mod: _tvm_def.TVMIRModule, layout: str, target: Platform, *,
                 output_labels: list[str] | None, model_path: str | None):
        self._ir_mod = mod
        self._layout = layout
        self._output_labels = output_labels
        self._target = target
        self._model_path = model_path

    def _shape_dict(self) -> dict[str, tuple[int, ...]]:
        """Return a mapping from each parameter name of ``main`` to its integer shape."""
        import tvm.relay.transform
        # Run type inference so that every parameter's checked_type is populated.
        typed_mod = tvm.relay.transform.InferType()(self._ir_mod)
        return {param.name_hint: tuple(dim.value for dim in param.checked_type.shape)
                for param in typed_mod['main'].params}

    def _dtype_dict(self) -> dict[str, ScalarType]:
        """Return a mapping from each parameter name of ``main`` to its scalar type."""
        import tvm.relay.transform
        # Run type inference so that every parameter's checked_type is populated.
        typed_mod = tvm.relay.transform.InferType()(self._ir_mod)
        return {param.name_hint: scalar_type_from_dtype(param.checked_type.dtype)
                for param in typed_mod['main'].params}

    @_sanitize_exceptions(ExceptionFuncType.LOADED_NET_EXECUTE)
    def execute(self, inputs: InputValues, *, log_level: int = logging.NOTSET) -> list[np.ndarray]:
        """
        Execute the loaded network using a software implementation of operators.

        This method runs the network with a single set of input tensor values and
        returns the corresponding output tensor values. The execution does not
        simulate processor behavior; it evaluates the TVM IR module directly.

        Args:
            inputs (InputValues): A dictionary mapping input names to their tensor
                data. Input tensors must be in channel-last layout (e.g., NHWC or NDHWC).
            log_level (int, optional): Sets the logging level for this API call.
                Defaults to ``logging.NOTSET``.

        Returns:
            list[np.ndarray]: The output tensors resulting from the model execution.

        Raises:
            UserFacingException: If an error occurs during the execution process.

        Execution Details:
            - A 4D input is transposed NHWC -> NCHW when the model layout is ``"NCHW"``;
              a 5D input is transposed NDHWC -> NCDHW when the model layout is ``"NCDHW"``.
            - Outputs are transposed back to channel-last layout, but only when *all*
              outputs have the rank matching the model's channel-first layout.
        """
        with sima_logger.ScopedLogLevel(log_level):
            if self._mod_evaluator is None:
                self._mod_evaluator = create_ir_evaluator(self._ir_mod)

            # API mandates input tensors have channel-last layout.
            # Do transpose for 4D and 5D input if the model itself has channel-first layout.
            transposed_inputs = {}
            for input_name, tensor in inputs.items():
                if len(tensor.shape) == 4 and self._layout == "NCHW":
                    sima_logger.sima_log_info(f"Transposing inputs shape {tensor.shape} from NHWC to NCHW.")
                    tensor = transpose_tensor_according_to_layout_strings(tensor, "NHWC", "NCHW")
                elif len(tensor.shape) == 5 and self._layout == "NCDHW":
                    sima_logger.sima_log_info(f"Transposing inputs shape {tensor.shape} from NDHWC to NCDHW.")
                    tensor = transpose_tensor_according_to_layout_strings(tensor, "NDHWC", "NCDHW")
                transposed_inputs[input_name] = tensor

            sima_logger.sima_log_info("Executing loaded net with input data\n\t" +
                                      "\n\t".join(f"{name}: {tensor.shape}"
                                                  for name, tensor in transposed_inputs.items()))
            output = run_ir_evaluator(self._ir_mod, self._mod_evaluator, transposed_inputs)

            # API mandates output tensors have channel-last layout. If the model has channel-first layout,
            # do transpose if output dimension matches model layout.
            if all(len(t.shape) == 4 for t in output) and self._layout == "NCHW":
                for idx, tensor in enumerate(output):
                    sima_logger.sima_log_info(f"Transposing output shape {tensor.shape} from NCHW to NHWC.")
                    output[idx] = transpose_tensor_according_to_layout_strings(tensor, "NCHW", "NHWC")
            elif all(len(t.shape) == 5 for t in output) and self._layout == "NCDHW":
                for idx, tensor in enumerate(output):
                    sima_logger.sima_log_info(f"Transposing output shape {tensor.shape} from NCDHW to NDHWC.")
                    output[idx] = transpose_tensor_according_to_layout_strings(tensor, "NCDHW", "NDHWC")
            return output

    @_sanitize_exceptions(ExceptionFuncType.LOADED_NET_QUANTIZE)
    def quantize(self, calibration_data: Iterable[InputValues], quantization_config: QuantizationParams, *,
                 automatic_layout_conversion: bool = False, arm_only: bool = False, simulated_arm: bool = False,
                 model_name: str | None = None, log_level: int = logging.NOTSET) -> Model:
        """
        Quantize the loaded neural network model using the provided calibration data
        and quantization configuration.

        If ``arm_only`` is ``False``, the model is calibrated and quantized for the
        MLA/EV74/A65 backends. If ``arm_only`` is ``True``, quantization is skipped and
        the model is partitioned for the A65 backend only, which is useful for testing.

        Args:
            calibration_data (Iterable[InputValues]): Dataset for calibration. Each
                sample is a dictionary mapping input names to calibration data.
            quantization_config (QuantizationParams): Parameters controlling the
                calibration and quantization process.
            automatic_layout_conversion (bool, optional): Enable automatic layout
                conversion during processing. Defaults to ``False``.
            arm_only (bool, optional): Skip quantization and target the ARM backend
                only. Defaults to ``False``.
            simulated_arm (bool, optional): Reserved for internal use. Uses a "CPU"
                backend that can be executed in the API but not compiled.
                Defaults to ``False``.
            model_name (Optional[str], optional): Name for the returned model.
                Defaults to ``None``.
            log_level (int, optional): Logging level for this API call.
                Defaults to ``logging.NOTSET``.

        Returns:
            Model: The quantized model (or the unquantized ARM-only model if
            ``arm_only`` is ``True``) together with the floating-point network.

        Raises:
            ValueError: If both ``arm_only`` and ``simulated_arm`` are ``True``.
            UserFacingException: If the quantization configuration is invalid or an
                error occurs during calibration or quantization.

        Example:
            .. code-block:: python

                # Load pre-processed calibration data
                dataset_f = np.load('preprocessed_data.npz')
                data = dataset_f['x']

                # Prepare calibration data as a list of dictionaries
                calib_data = []
                calib_images = 100
                for i in range(calib_images):
                    inputs = {'input_1': data[i]}
                    calib_data.append(inputs)

                # Quantize the model
                quant_model = loaded_net.quantize(
                    calibration_data=calib_data,
                    quantization_config=default_quantization,
                    model_name='quantized_model'
                )
        """
        with sima_logger.ScopedLogLevel(log_level):
            _validate_quantization_configs(quantization_config, self._target)
            sima_logger.sima_log_info(f"Quantize loaded net, layout = {self._layout}, arm_only = {arm_only}")
            calibration_method = quantization_config.calibration_method
            # _validate_quantization_configs accepts a falsy calibration_method (e.g. None),
            # so do not assume the object has a `.name` attribute.
            sima_logger.sima_log_info(f"Calibration method = {getattr(calibration_method, 'name', calibration_method)}")
            sima_logger.sima_log_dbg(f"Quantization configuration = {quantization_config.__dict__}")

            if simulated_arm:
                if arm_only:
                    raise ValueError("Unsupported combination of parameters: arm_only=True and simulated_arm=True")
                # Use the "CPU" fake backend, which can be executed in the API but not compiled
                sima_logger.sima_log_info("Internal reserved flag simulated_arm is set to True")
                enabled_backends = CompileMode.MLA_EV74_CPU
            elif arm_only:
                enabled_backends = CompileMode.A65
            else:
                enabled_backends = CompileMode.MLA_EV74_A65

            afe_processing_configs = _create_afe_processing_configs(
                quantization_config=quantization_config, is_quantized=False, layout=self._layout,
                name=model_name, enabled_backends=enabled_backends,
                automatic_layout_conversion=automatic_layout_conversion, target=self._target
            )
            fp32_net = _transform_irmod_to_awesomenet(self._ir_mod, afe_processing_configs,
                                                      output_labels=self._output_labels,
                                                      model_path=self._model_path)
            # Quantize a deep copy; fp32_net itself is handed to Model alongside the result.
            net = copy.deepcopy(fp32_net)
            assert net.status == Status.RELAY
            if not arm_only:
                try:
                    calibrate_and_quantize_net = _calibration_quantization(
                        afe_processing_configs.optimization_configs,
                        custom_quantization_configs=quantization_config.custom_quantization_configs
                    )
                    net = calibrate_and_quantize_net(net, calibration_data).run()
                except Exception as e:
                    _sanitize_afe_error("Error occured during calibration and quantization process. "
                                        "Please verify that the calibration data has an NHWC layout. "
                                        "Contact SiMa Support if the error still persists.", e)
                assert net.status == Status.SIMA_QUANTIZED
            return Model(net, fp32_net)

    @_sanitize_exceptions(ExceptionFuncType.LOADED_NET_QUANTIZE)
    def quantize_with_accuracy_feedback(self, calibration_data: Iterable[InputValues],
                                        evaluation_data: Iterable[tuple[InputValues, GroundTruth]],
                                        quantization_config: QuantizationParams, *,
                                        accuracy_score: Statistic[tuple[list[np.ndarray], GroundTruth], float],
                                        target_accuracy: float, automatic_layout_conversion: bool = False,
                                        max_optimization_steps: int | None = None,
                                        model_name: str | None = None,
                                        log_level: int = logging.NOTSET) -> Model:
        """
        Quantize the model with accuracy feedback using a mixed-precision approach.

        The model is calibrated and quantized, its accuracy is evaluated with
        ``accuracy_score`` on ``evaluation_data``, and precision is adjusted over up
        to ``max_optimization_steps`` iterations to try to reach ``target_accuracy``.

        Args:
            calibration_data (Iterable[InputValues]): Dataset used for calibration.
                Each sample is a dictionary mapping input names to calibration data.
            evaluation_data (Iterable[tuple[InputValues, GroundTruth]]): Dataset used
                to evaluate accuracy; each element pairs input data with ground truth.
            quantization_config (QuantizationParams): Parameters controlling the
                quantization process. Activations must use 8-bit integer quantization.
            accuracy_score (Statistic[tuple[list[np.ndarray], GroundTruth], float]):
                Metric used to calculate accuracy during the quantization process.
            target_accuracy (float): Accuracy the quantized model should achieve.
            automatic_layout_conversion (bool, optional): Enable automatic layout
                conversion during processing. Defaults to ``False``.
            max_optimization_steps (Optional[int], optional): Maximum number of
                mixed-precision optimization steps. Must be greater than 1.
                Defaults to ``_MIXED_PRECISION_SEARCH_LIMIT`` when ``None``.
            model_name (Optional[str], optional): Name for the resulting model.
                Defaults to ``None``.
            log_level (int, optional): Logging level for this API call.
                Defaults to ``logging.NOTSET``.

        Returns:
            Model: The quantized model along with its corresponding floating-point network.

        Raises:
            UserFacingException:
                - If the quantization configuration is invalid, or activations are not
                  8-bit integer.
                - If ``max_optimization_steps`` is less than or equal to 1.
                - If an error occurs during the mixed-precision quantization process.
        """
        with sima_logger.ScopedLogLevel(log_level):
            _validate_quantization_configs(quantization_config, self._target)
            activation_scheme = quantization_config.activation_quantization_scheme
            if activation_scheme.bf16 or activation_scheme.bits != 8:
                raise sima_logger.UserFacingException(
                    "Unsupported quantization parameters. Activation tensors must use 8 bits of precision."
                )
            if max_optimization_steps is None:
                max_optimization_steps = _MIXED_PRECISION_SEARCH_LIMIT
            if max_optimization_steps <= 1:
                raise sima_logger.UserFacingException("Parameter max_optimization_steps must be greater than 1.")

            afe_processing_configs = _create_afe_processing_configs(
                quantization_config=quantization_config, is_quantized=False, layout=self._layout,
                name=model_name, enabled_backends=CompileMode.MLA_EV74_A65,
                automatic_layout_conversion=automatic_layout_conversion, target=self._target
            )
            fp32_net = _transform_irmod_to_awesomenet(self._ir_mod, afe_processing_configs,
                                                      output_labels=self._output_labels,
                                                      model_path=self._model_path)
            compiler_passes = _noise_based_mixed_precision_quantization(
                afe_processing_configs.optimization_configs, accuracy_score,
                target_accuracy=target_accuracy, max_iterations=max_optimization_steps
            )
            # NOTE(review): unlike quantize(), fp32_net is not deep-copied before the passes run;
            # confirm the mixed-precision passes do not modify it in place, since it is also
            # returned in the Model.
            try:
                quantized_net = compiler_passes(fp32_net, calibration_data, calibration_data,
                                                evaluation_data).run()
            except Exception as e:
                _sanitize_afe_error("Error occurred during mixed-precision quantization process.", e)
            assert quantized_net.status == Status.SIMA_QUANTIZED
            return Model(quantized_net, fp32_net)

    @_sanitize_exceptions(ExceptionFuncType.LOADED_NET_CONVERT)
    def convert_to_sima_quantization(self, *, requantization_mode: RequantizationMode = RequantizationMode.sima,
                                     model_name: str | None = None,
                                     log_level: int = logging.NOTSET) -> Model:
        """
        Convert an already-quantized loaded model to a SiMa-quantized ``Model``.

        No calibration or quantization is performed: the model is marked as already
        quantized and transformed for the MLA/EV74/A65 backends.

        Args:
            requantization_mode (RequantizationMode, optional): Requantization mode to
                use. Defaults to ``RequantizationMode.sima``.
            model_name (Optional[str], optional): Name for the resulting model.
                Defaults to ``None``.
            log_level (int, optional): Logging level for this API call.
                Defaults to ``logging.NOTSET``.

        Returns:
            Model: The converted network (no separate floating-point network is attached).
        """
        with sima_logger.ScopedLogLevel(log_level):
            # quantization will be skipped, so quantization_config has no effect
            afe_processing_configs = _create_afe_processing_configs(
                quantization_config=afe.apis.defines.default_quantization, is_quantized=True,
                layout=self._layout, name=model_name, enabled_backends=CompileMode.MLA_EV74_A65,
                requantization_mode=requantization_mode, target=self._target
            )
            # NOTE(review): unlike the other methods, output_labels is not forwarded here;
            # confirm whether that is intentional.
            net = _transform_irmod_to_awesomenet(self._ir_mod, afe_processing_configs,
                                                 model_path=self._model_path)
            assert net.status == Status.SIMA_QUANTIZED
            return Model(net)
def _require_nonempty_input_types(params: ImporterParams) -> None:
    """Raise UserFacingException if ``params.input_types`` is None."""
    if params.input_types is None:
        raise sima_logger.UserFacingException('Input types are not provided')


def _require_nonempty_input_shapes(params: ImporterParams) -> None:
    """Raise UserFacingException if ``params.input_shapes`` is None."""
    if params.input_shapes is None:
        raise sima_logger.UserFacingException('Input shapes are not provided.')


def _require_nonempty_input_names(params: ImporterParams) -> None:
    """Raise UserFacingException if ``params.input_names`` is None."""
    if params.input_names is None:
        raise sima_logger.UserFacingException('Input names are not provided')


def _require_nonempty_output_names(params: ImporterParams) -> None:
    """Raise UserFacingException if ``params.output_names`` is None."""
    if params.output_names is None:
        raise sima_logger.UserFacingException('Output names are not provided')


@_sanitize_exceptions(ExceptionFuncType.LOADED_NET_LOAD)
def load_model(params: ImporterParams, *, target: Platform = gen1_target,
               log_level: int = logging.NOTSET) -> LoadedNet:
    """
    Load a machine learning model into the SiMa Model SDK for further processing
    such as quantization or compilation.

    The model format is detected from ``params.file_paths`` and must match
    ``params.format``. The fields required for that format are checked, then the
    model is imported and wrapped in a ``LoadedNet``.

    Args:
        params (ImporterParams): Import parameters including model file paths,
            input shapes, input types, names, and other configuration.
        target (Platform, optional): Target platform for which the model should be
            loaded. Defaults to ``gen1_target``.
        log_level (int, optional): Logging level for the loading process.
            Defaults to ``logging.NOTSET``.

    Returns:
        LoadedNet: The loaded model. Its ``model_path`` is the first entry of
        ``params.file_paths``.

    Raises:
        UserFacingException:
            - If no model file paths are provided.
            - If the detected model format does not match ``params.format``.
            - If required parameters for the detected model format are missing.
            - If the model format is unsupported.

    Supported Model Formats and Required Parameters:
        - ONNX, TFLite, Caffe, Caffe2: Requires non-None ``input_types`` and ``input_shapes``.
        - PyTorch: Requires non-None ``input_names`` and ``input_shapes``.
        - TensorFlow (v1 & v2): Requires non-None ``output_names`` and ``input_shapes``.
        - Keras: Requires non-None ``input_shapes``.

    Example:
        >>> params = ImporterParams(
        ...     file_paths=["model.onnx"],
        ...     input_shapes={"input_1": (1, 3, 224, 224)},
        ...     input_types={"input_1": "float32"}
        ... )
        >>> loaded_model = load_model(params)
    """
    with sima_logger.ScopedLogLevel(log_level):
        if not params.file_paths:
            error_message = "Path to model file is not provided."
            sima_logger.sima_log_error(error_message)
            # Previously the exception was raised without its message.
            raise sima_logger.UserFacingException(error_message)

        model_format = detect_format(params.file_paths)
        if model_format != params.format:
            error_message = f"Expected {params.format}, got {model_format}"
            sima_logger.sima_log_error(error_message)
            raise sima_logger.UserFacingException(error_message)

        sima_logger.sima_log_info(f"Loading {params.file_paths} in {model_format.value} format")
        sima_logger.sima_log_dbg(f"ImporterParams = {params.__dict__}")

        # Check the fields each importer needs before running the import.
        if model_format in (ModelFormat.onnx, ModelFormat.tflite, ModelFormat.caffe, ModelFormat.caffe2):
            _require_nonempty_input_types(params)
            _require_nonempty_input_shapes(params)
        elif model_format == ModelFormat.pytorch:
            _require_nonempty_input_names(params)
            _require_nonempty_input_shapes(params)
        elif model_format in (ModelFormat.tensorflow, ModelFormat.tensorflow2):
            _require_nonempty_output_names(params)
            _require_nonempty_input_shapes(params)
        elif model_format == ModelFormat.keras:
            _require_nonempty_input_shapes(params)
        else:
            error_message = f"Model format {model_format} not supported."
            sima_logger.sima_log_error(error_message)
            raise sima_logger.UserFacingException(error_message)

        # file_paths is known to be non-empty here (checked above).
        mod, output_labels = _import_model(params).run()
        return LoadedNet(mod, params.layout, target, output_labels=output_labels,
                         model_path=params.file_paths[0])