# Source code for afe.apis.compilation_job

#########################################################
# Copyright (C) 2022 SiMa Technologies, Inc.
#
# This material is SiMa proprietary and confidential.
#
# This material may not be copied or distributed without
# the express prior written permission of SiMa.
#
# All rights reserved.
#########################################################
# Code owner: Christopher Rodrigues
#########################################################
"""
The API for controlling a compilation job.  This API is
intended for use in a user-defined module that is passed
to the command-line interface.

The recommended usage is to construct a CompilationJob and
pass it to set_compilation_job.  All parameters that direct
compilation are held in the CompilationJob, which is then
read by the command-line driver.
"""
from typing import Generic, List, Tuple, Iterable, Any
import dataclasses

from afe.apis.compilation_job_base import Tensor, Tensors, GroundTruth
from afe.apis.transform import Transform
from afe.apis.statistic import Statistic


@dataclasses.dataclass(frozen=True)
class CompilationJob(Generic[GroundTruth]):
    """
    A specification of how to calibrate, quantize, evaluate, and compile a model.

    The class is generic in ``GroundTruth``, the type of the reference value
    paired with each evaluation input.

    Instances are frozen: attributes cannot be reassigned after construction.
    The containers passed in (lists, iterables) are stored as-is; they are
    neither copied nor made immutable.

    Fields are positional in the order declared below, so that order is part
    of the constructor's interface.

    NOTE: because ``frozen=True`` is combined with the default ``eq=True``,
    ``dataclasses`` generates ``__hash__`` from all fields.  The ``List``
    fields are unhashable, so calling ``hash()`` on an instance (or using it
    as a dict key / set member) raises ``TypeError``.
    """

    # Preprocessing transforms to apply to each model input.
    preprocess_transforms: List[Transform]

    # Postprocessing transforms to apply to each model output.
    postprocess_transforms: List[Transform]

    # Input of the model for calibration.
    # NOTE(review): if a one-shot iterator (e.g. a generator) is supplied, it
    # can only be traversed once -- confirm whether the driver iterates this
    # more than once.
    calibration_input: Iterable[Tensors]

    # Inputs and ground truth for accuracy evaluation.
    # NOTE(review): same single-traversal caveat as calibration_input.
    evaluation_input: Iterable[Tuple[Tensors, GroundTruth]]

    # Evaluator of result quality.  Uses postprocessed tensors and ground
    # truth, and produces its result as a human-readable string.
    evaluate_result: Statistic[Tuple[Tensors, GroundTruth], str]
def set_compilation_job(job: CompilationJob[Any]) -> None:
    """
    Use the given CompilationJob to control compilation.

    The job is stored as ``afe.driver.cli.commands.saved_compilation_job``,
    where the command-line driver reads it.  Each call overwrites the
    previously stored job, so if this is called multiple times, the job
    passed to the final call is the one that will be used.

    :param job: The compilation job to register with the command-line driver.
    :return: None.  The only effect is the module-level assignment above.
    """
    # Imported inside the function rather than at module top level,
    # presumably to avoid an import cycle between this public API module and
    # the CLI driver -- TODO confirm.
    import afe.driver.cli.commands
    afe.driver.cli.commands.saved_compilation_job = job