Source code for afe.driver.passes

#########################################################
# Copyright (C) 2022 SiMa Technologies, Inc.
#
# This material is SiMa proprietary and confidential.
#
# This material may not be copied or distributed without
# the express prior written permission of SiMa.
#
# All rights reserved.
#########################################################
# Code owner: Christopher Rodrigues
#########################################################
"""
Compiler passes.

These functions wrap IR transforms into compiler passes to be called
by driver or API functions.  Driver code should construct a compilation
pipeline in CompileStep, then run it.
"""
import copy
import functools
import itertools
import dataclasses
import math
from typing import Callable, Iterable, Dict, Optional, Any, List, Tuple, TypeVar, Generic
import os
import numpy as np

from sima_utils.logging import sima_logger

from afe._tvm._defines import TVMIRModule
from afe.apis._sanitize_errors import sanitize_tvm_error
from afe.apis.compilation_job_base import GroundTruth
from afe.apis.defines import InputValues
from afe.backends import Backend
from afe.core.configs import CalibrationConfigs, QuantizationConfigs, OptimizationConfigs, ModelConfigs, \
    QuantizationPrecision
from afe.core.mixed_precision import ranking
from afe.core.utils import dump_yaml_npz, dump_configs_to_yaml, LengthHintedIterable
from afe.core.graph_analyzer import graph_analyzer
import afe.core.graph_analyzer.utils as graph_analyzer_utils
from afe.driver.statistic import Statistic
from afe.driver.compile_step import CompileStep
from afe.ir.execute import create_node_quant_executor
from afe.ir.net import AwesomeNet, NodeName
from afe.ir.defines import BiasCorrectionType
from afe.ir.transform import Calibrate, UpdateQuantizationConfigs, Quantize, InsertNodeObservers, RemoveNodeObservers
from afe.ir.transform.calibration_transforms import InsertPerChannelNodeObservers, CalibratePerChannel
from afe.ir.transform.requantization_fusion import FuseRequantizations
from afe.ir.transform.channel_scaling import get_pairings, pairings_update_pass
from afe.ir.transform.requantization_hoisting.requantization_hoisting import HoistRequantization
from afe.load.importers.general_importer import ImporterParams, import_from_import_params, default_layout, \
    update_with_detected_format

InputDict = Dict[NodeName, np.ndarray]

_A = TypeVar("_A")

# Maximum number of binary search steps for automatic mixed precision quantization
MIXED_PRECISION_SEARCH_LIMIT = 20

def import_model(config: ImporterParams) \
        -> CompileStep[Tuple[TVMIRModule, Optional[List[str]]]]:
    """
    Create a compiler pass to import a model.

    :param config: Configuration for import of a model.
    :return: A compiler step to import a model. It returns the imported TVM module
        and the module's output names. Output names are only included if the source
        model has output names.
    """
    config = update_with_detected_format(config)
    try:
        mod, output_labels = import_from_import_params(config)
    except Exception as e:
        sanitize_tvm_error("Error occurred while importing model.", e)
    return CompileStep.pure((mod, output_labels))

def tvm_transformations(*, layout: Optional[str] = 'NCHW',
                        index_to_backend_dict: Optional[Dict[int, Backend]] = None,
                        is_quantized: bool = False,
                        name: str,
                        framework: Optional[str] = None) \
        -> Callable[[TVMIRModule], CompileStep[TVMIRModule]]:
    """
    Create a compiler pass to run TVM transformations on a model.
    The TVM transformations include ConvertLayout.

    :param layout: Data layout of activation tensors in the input model.
    :param index_to_backend_dict: Assignment of nodes to backends. Assignments given here
        override the partitioning algorithm's decision.
    :param is_quantized: Whether the input is quantized. If quantized, partitioning will
        decide whether a given operator can execute on the MLA. If not quantized, it will
        decide whether a given operator can be quantized and then execute on the MLA.
    :param name: Name of the model.
    :param framework: Name of the model's source framework, if known.
    :return: Compiler pass that transforms Relay IR.
    """
    from afe.load.loader import _transform_and_convert_tvm_ir_to_awesome_net

    def do_transformations(mod: TVMIRModule) -> CompileStep[TVMIRModule]:
        mod = _transform_and_convert_tvm_ir_to_awesome_net(mod, name, layout, index_to_backend_dict,
                                                           is_quantized, framework=framework)
        return CompileStep.pure(mod)

    return do_transformations

def import_and_transform(config: ImporterParams, *,
                         name: str,
                         index_to_backend_dict: Optional[Dict[int, Backend]] = None,
                         is_quantized: bool = False) \
        -> CompileStep[TVMIRModule]:
    """
    Create a compiler pass to import and run TVM transformations on a model.
    See import_model and tvm_transformations for parameter documentation.

    The returned compile step does the same processing as the old load_*_model functions.
    """
    config = update_with_detected_format(config)
    layout = config.layout if config.layout is not None else default_layout(config.format)
    return import_model(config) \
        .map(lambda x: x[0]) \
        .then(tvm_transformations(name=name, layout=layout,
                                  index_to_backend_dict=index_to_backend_dict,
                                  is_quantized=is_quantized,
                                  framework=config.format.value))

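# A minimal usage sketch (hypothetical; how ImporterParams is constructed depends
# on the source framework, so the constructor arguments below are illustrative
# placeholders only):
#
#     params = ImporterParams(...)                       # describe the source model
#     step = import_and_transform(params, name="my_model")
#     # Compose further passes onto `step` with .then(...) and .map(...), then
#     # run the finished pipeline from driver code.
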
def update_quantization_configs(quantization_config: QuantizationConfigs, *,
                                custom_quantization_configs: Optional[Dict[NodeName, Dict[str, Any]]] = None) \
        -> Callable[[AwesomeNet], CompileStep[AwesomeNet]]:
    """
    Create a compiler pass that records quantization parameters on each node in a network.
    The compiler pass modifies the network.

    :param quantization_config: Global configuration parameters for quantization. These will be
        inserted on all nodes, but will not override previously inserted parameters.
    :param custom_quantization_configs: Dictionary to override quantization settings for specific
        nodes. This parameter may only be used in tests.

        Where custom_quantization_configs[node_name][field_name] = value, it will set the given
        node's given QuantizationConfigs field to the given value. For example, passing the value
        {"MLA_1/conv2d_add_84": {"output_int32": True}} will override the configuration of the
        node named "MLA_1/conv2d_add_84" by setting its output_int32 field to True.
    :return: Compiler pass to update parameters. The pass mutates and returns its input.
    """
    def do_update(net: AwesomeNet) -> CompileStep[AwesomeNet]:
        # Copy information from config and custom_quantization_configs into graph nodes.
        # It will be read by quantization. Mutates net.
        UpdateQuantizationConfigs(quantization_config, custom_quantization_configs)(net)
        return CompileStep.pure(net)

    return do_update

def calibration(calibration_config: CalibrationConfigs) \
        -> Callable[[AwesomeNet, Iterable[InputDict]], CompileStep[AwesomeNet]]:
    """
    Create a compiler pass to calibrate a network.

    :param calibration_config: Configuration for calibration.
    :return: A compiler pass to calibrate a network. The pass mutates and returns its input.
    """
    def do_calibrate(net: AwesomeNet, inputs: Iterable[InputDict]) \
            -> CompileStep[AwesomeNet]:
        # Insert observers calculating activation statistics.
        InsertNodeObservers(calibration_method=calibration_config.calibration_method)(net)

        if calibration_config.num_calibration_samples is not None:
            # Restrict the number of inputs used for calibration
            inputs = itertools.islice(inputs, calibration_config.num_calibration_samples)
            length_hint = calibration_config.num_calibration_samples
        elif isinstance(inputs, LengthHintedIterable):
            length_hint = inputs.get_length()
        elif isinstance(inputs, list):
            length_hint = len(inputs)
        else:
            length_hint = None

        calibrator = Calibrate(length_hint=length_hint, dataset=inputs)
        calibrator(net)  # Mutates net
        return CompileStep.pure(net)

    return do_calibrate

def quantization(input_dataset: Optional[Iterable[InputDict]]) \
        -> Callable[[AwesomeNet], CompileStep[AwesomeNet]]:
    """
    Create a compiler pass to quantize a network. Quantization configuration is set by the
    UpdateQuantizationConfigs transform and the calibration pass.

    :param input_dataset: Input data to use during quantization. It is needed for iterative
        bias correction; pass None when iterative bias correction is not used.
    :return: A compiler pass to quantize a network. The pass mutates and returns its input.
    """
    def do_quantize(net: AwesomeNet) -> CompileStep[AwesomeNet]:
        # Run quantization. Mutates net.
        Quantize()(net, input_dataset)

        # Remove NodeObservers in order to remove PyTorch infrastructure from AwesomeNet
        RemoveNodeObservers()(net)

        # Move Requantize nodes as early as possible
        HoistRequantization()(net)

        # Simplify Requantize nodes that were inserted by quantization
        FuseRequantizations()(net)
        return CompileStep.pure(net)

    return do_quantize

def equalization(calibration_config: CalibrationConfigs, quantization_config: QuantizationConfigs) \
        -> Callable[[AwesomeNet, Iterable[InputDict]], CompileStep[AwesomeNet]]:
    """
    Run SmoothQuant and/or channel equalization if they are enabled in config.
    This pass should run before quantization.
    """
    # If none of these optimizations will run, return a pass that does nothing
    if not quantization_config.channel_equalization.get() and not quantization_config.smooth_quant.get():
        return lambda net, input_dataset: CompileStep.pure(net)

    def do_passes(net: AwesomeNet, input_dataset: Iterable[InputDict]) -> CompileStep[AwesomeNet]:
        InsertPerChannelNodeObservers()(net)

        if calibration_config.num_calibration_samples is not None:
            # Restrict the number of inputs used for calibration
            input_dataset = itertools.islice(input_dataset, calibration_config.num_calibration_samples)
            length_hint = calibration_config.num_calibration_samples
        elif isinstance(input_dataset, LengthHintedIterable):
            length_hint = input_dataset.get_length()
        else:
            length_hint = None

        CalibratePerChannel(length_hint, input_dataset)(net)

        # Run a pass to find node pairs to scale, and a second pass to compute scales and
        # update weights using channel equalization
        pairings_lists = get_pairings(net)
        pairings_update_pass(pairings_lists)

        RemoveNodeObservers()(net)
        return CompileStep.pure(net)

    return do_passes

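# Background note (the general technique, not a description of this specific
# implementation): channel equalization rescales a pair of adjacent layers so
# that their per-channel weight ranges become comparable. For a scale vector s,
# the first layer's output channels are divided by s and the second layer's
# corresponding input channels are multiplied by s, which leaves the composed
# floating-point function unchanged while making both layers easier to quantize.
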
def calibration_quantization(config: OptimizationConfigs, *,
                             custom_quantization_configs: Optional[Dict[NodeName, Dict[str, Any]]] = None) \
        -> Callable[[AwesomeNet, Iterable[InputDict]], CompileStep[AwesomeNet]]:
    """
    Run calibration and quantization. Quantization-related optimizations that can run
    at the same time are included here.

    :param config: Parameters for calibration and quantization.
    :param custom_quantization_configs: Dictionary to override quantization settings for specific
        nodes. This parameter may only be used in tests.
    :return: A compiler pass to calibrate and quantize a network. The pass mutates and returns
        its input.
    """
    def do_passes(net: AwesomeNet, input_dataset: Iterable[InputDict]) \
            -> CompileStep[AwesomeNet]:
        return update_quantization_configs(config.quantization_configs,
                                           custom_quantization_configs=custom_quantization_configs)(net) \
            .then(lambda net1: equalization(config.calibration_configs,
                                            config.quantization_configs)(net1, input_dataset)) \
            .then(lambda net2: calibration(config.calibration_configs)(net2, input_dataset)) \
            .then(quantization(input_dataset
                               if (config.quantization_configs.biascorr_type.get()
                                   == BiasCorrectionType.ITERATIVE)
                               else None))

    return do_passes

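# Usage sketch (hypothetical; `opt_configs` stands for an OptimizationConfigs
# value, and `net`/`samples` for a floating-point AwesomeNet and its calibration
# inputs):
#
#     quantize_step = calibration_quantization(opt_configs)(net, samples)
#     # quantize_step is a CompileStep[AwesomeNet]; running it yields the
#     # calibrated and quantized network.
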
def evaluation(criterion: Statistic[Tuple[List[np.ndarray], GroundTruth], _A], *,
               fast_mode: bool = False) \
        -> Callable[[AwesomeNet, Iterable[Tuple[InputValues, GroundTruth]]], CompileStep[_A]]:
    """
    Execute a model on an input set and compute an aggregate result from the model's outputs.

    The primary use case of this function is to estimate a model's accuracy. In this case,
    criterion computes an accuracy metric for the model over the data set.

    :param criterion: Function to compute on the model's output and auxiliary data.
    :param fast_mode: Whether to execute in fast mode.
    :return: A compiler pass that takes an AwesomeNet and data source, runs the model,
        and returns the function's result.
    """
    def do_evaluation(net: AwesomeNet, evaluation_data: Iterable[Tuple[InputValues, GroundTruth]]) \
            -> CompileStep[_A]:
        node_callable = create_node_quant_executor(fast_mode=fast_mode)
        accumulator = criterion.initialize()
        for inputs, gr_truth in evaluation_data:
            output = net.run(inputs=inputs, node_callable=node_callable)
            criterion.update(accumulator, (output, gr_truth))
        evaluation_result = criterion.finish(accumulator)
        return CompileStep.pure(evaluation_result)

    return do_evaluation

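# An illustrative criterion (hypothetical sketch: it assumes Statistic can be
# implemented by supplying initialize/update/finish, matching how `evaluation`
# calls them above; the actual Statistic interface is defined in
# afe.driver.statistic and may differ):
#
#     class Top1Accuracy(Statistic[Tuple[List[np.ndarray], GroundTruth], float]):
#         def initialize(self):
#             return [0, 0]                        # [correct, total]
#         def update(self, acc, value):
#             outputs, label = value
#             acc[0] += int(np.argmax(outputs[0]) == label)
#             acc[1] += 1
#         def finish(self, acc):
#             return acc[0] / max(acc[1], 1)
#
#     accuracy_pass = evaluation(Top1Accuracy(), fast_mode=True)
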
@dataclasses.dataclass
class BinarySearchState(Generic[_A]):
    """
    State of binary search.

    :param lo: Low bound of search range. This is the highest index that was found
        to not satisfy the search condition.
    :param hi: High bound of search range. This is the lowest index that was found
        to satisfy the search condition.
    :param hi_value: Value associated with the high bound. This will be returned if
        the high value is selected when the search finishes.
    :param iteration: Iteration of the search, starting from zero. Used for deciding
        to stop early.
    """
    lo: int
    hi: int
    hi_value: _A
    iteration: int

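# The binary_search driver that consumes this state is not shown in this listing.
# A rough sketch of its contract, inferred from the call site in
# noise_based_mixed_precision_quantization below (the real implementation may
# differ):
#
#     def binary_search(step, lo, hi, max_iterations, *, procedure_name):
#         # `step(mid)` returns a CompileStep[tuple[bool, _A]] reporting whether
#         # index `mid` satisfies the search condition. A passing result lowers
#         # BinarySearchState.hi and records hi_value; a failing result raises
#         # BinarySearchState.lo. The search stops when the bounds meet or after
#         # max_iterations, producing a CompileStep[Optional[_A]] whose result is
#         # hi_value if some index passed, otherwise None.
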
def _annotate_quantization_for_mixed_precision(
        net: AwesomeNet,
        node_ranking: list[list[NodeName]],
        cutoff: int,
) -> CompileStep[AwesomeNet]:
    """
    Annotate a network with one set of parameters for mixed precision search.
    High precision is used for nodes in node_ranking[:cutoff].

    :param net: Net to quantize. This function does not modify it.
    :param node_ranking: Ranking of nodes, ordered from high to low sensitivity.
    :param cutoff: Index of the first low-precision node in the ranking. Indices before
        this use high precision.
    :return: New network with high-precision nodes annotated.
    """
    high_precision_nodes = {n for n_set in node_ranking[:cutoff] for n in n_set}

    def do_annotate():
        # Make an annotated copy of net.
        net2 = copy.deepcopy(net)
        ranking.annotate_int16_nodes(net2, high_precision_nodes)
        return net2

    return CompileStep.from_thunk(do_annotate)


def _single_mixed_precision_iteration(
        net: AwesomeNet,
        node_ranking: list[list[NodeName]],
        threshold: float,
        quantization_pass: Callable[[AwesomeNet], CompileStep[AwesomeNet]],
        evaluation_pass: Callable[[AwesomeNet], CompileStep[float]],
) -> Callable[[int], CompileStep[tuple[bool, AwesomeNet]]]:
    """
    Run one iteration of quantization for mixed precision.

    :param net: Network to quantize. This function does not modify it.
    :param node_ranking: Ranking of nodes from high to low sensitivity. Nodes having equal
        sensitivity reside together in the inner list.
    :param threshold: Minimum acceptable quantized accuracy.
    :param quantization_pass: Compiler pass that quantizes the network.
    :param evaluation_pass: Compiler pass that evaluates accuracy of the network.
    :return: Iteration function for binary search. Takes a cutoff value and returns a tuple of
        (whether the quantized accuracy is good, the quantized network).
    """
    def make_result(cutoff: int, acc: float, qnet: AwesomeNet) -> tuple[bool, AwesomeNet]:
        sima_logger.sima_log_info(
            f"Model quantized with {cutoff} sensitivity classes in 16-bit precision has accuracy {acc}"
        )
        return acc >= threshold, qnet

    def do_passes(cutoff: int):
        return _annotate_quantization_for_mixed_precision(net, node_ranking, cutoff) \
            .then(quantization_pass) \
            .then(lambda qnet: evaluation_pass(qnet).map(lambda accuracy: make_result(cutoff, accuracy, qnet)))

    return do_passes

def noise_analysis() \
        -> Callable[[AwesomeNet, AwesomeNet, Iterable[InputDict]],
                    CompileStep[Dict[graph_analyzer_utils.Metric, graph_analyzer.AnalyzedResultDict]]]:
    """
    Analyze noise that is introduced by quantization.

    :return: A compiler pass to analyze noise. The pass takes as parameters an un-quantized net,
        a quantized net derived from it, and the evaluation input. It executes both nets on the
        evaluation inputs and compares the values at each layer to estimate quantization noise.
        It returns the analysis results.
    """
    mode = graph_analyzer_utils.QuantizedGraphAnalyzerMode.local_feed

    def do_passes(fp_net: AwesomeNet, quant_net: AwesomeNet, input_dataset: Iterable[InputDict]) \
            -> CompileStep[Dict[graph_analyzer_utils.Metric, graph_analyzer.AnalyzedResultDict]]:
        analyzer = graph_analyzer.QuantizedGraphAnalyzer(mode=mode)
        analyzer.analyze(fp_net, quant_net, input_dataset, graph_analyzer_utils.Metric.psnr)
        results = analyzer.analyzed_results
        return CompileStep.pure(results)

    return do_passes

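# Note: the mixed-precision search below converts these per-node PSNR results
# into a sensitivity score of 1 / (mean PSNR + 1e-10), so the nodes with the
# noisiest (lowest-PSNR) outputs are ranked as the most sensitive and become the
# first candidates for int16 precision.
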
def noise_based_mixed_precision_quantization(
        config: OptimizationConfigs,
        criterion: Statistic[Tuple[List[np.ndarray], GroundTruth], float],
        *,
        target_accuracy: float,
        custom_quantization_configs: Optional[Dict[NodeName, Dict[str, Any]]] = None,
        max_iterations: int = MIXED_PRECISION_SEARCH_LIMIT,
        fast_mode: bool = True
) -> Callable[[AwesomeNet, Iterable[InputDict], Iterable[InputDict],
               Iterable[tuple[InputValues, GroundTruth]]],
              CompileStep[AwesomeNet]]:
    """
    Do mixed-precision quantization using noise analysis to choose precision.

    It will first quantize the model with int8 precision and measure its quantization noise
    on the analysis dataset. Then it will try to minimize the number of int16 nodes while
    still achieving the target accuracy. It raises an exception if it cannot achieve the
    target accuracy.

    A network achieves the target accuracy if evaluating it using 'criterion' returns a
    number that is at least 'target_accuracy'.

    :param config: Parameters for calibration and quantization. The quantization precision
        must be int8.
    :param criterion: Method of evaluating accuracy on a data set.
    :param target_accuracy: Desired accuracy of the network.
    :param custom_quantization_configs: Dictionary to override quantization settings for specific
        nodes. This parameter may only be used in tests.
    :param max_iterations: Maximum number of binary search steps to perform.
    :param fast_mode: Whether to use fast mode when executing the network.
    :return: A compiler pass that does mixed-precision quantization given a floating-point
        model, a calibration dataset, an analysis dataset, and an evaluation dataset.
    """
    assert config.quantization_configs.quantization_precision.get() == QuantizationPrecision.INT_8, \
        "For mixed-precision quantization, the optimization configuration must use int8 precision"

    quantization_pass = calibration_quantization(config, custom_quantization_configs=custom_quantization_configs)
    evaluation_pass = evaluation(criterion, fast_mode=fast_mode)

    def get_ranking(stats: Dict[graph_analyzer_utils.Metric, graph_analyzer.AnalyzedResultDict]) \
            -> list[list[NodeName]]:
        # Make a sensitivity ranking from the stats returned by noise analysis.
        assert len(stats) == 1, "Analysis must use a single metric"
        try:
            stat = stats['psnr']
        except KeyError:
            raise ValueError("Expected analysis results from psnr metric")

        # Convert the calculated stat to a sensitivity estimate for each node.
        stat = {k: 1 / (np.mean(v) + 1e-10) for k, v in stat.items()}
        ret = ranking.rank_int16_nodes(stat)
        sima_logger.sima_log_info(
            f"Model graph nodes were put into {len(ret)} sensitivity classes for mixed precision quantization"
        )
        return ret

    def do_binary_search(fp_net: AwesomeNet,
                         calibration_dataset: Iterable[InputDict],
                         evaluation_dataset: Iterable[tuple[InputValues, GroundTruth]],
                         node_ranking: list[list[NodeName]]) -> CompileStep[Optional[AwesomeNet]]:
        return binary_search(
            _single_mixed_precision_iteration(
                fp_net, node_ranking, target_accuracy,
                lambda net: quantization_pass(net, calibration_dataset),
                lambda net: evaluation_pass(net, evaluation_dataset)
            ),
            0, len(node_ranking), max_iterations,
            procedure_name="mixed-precision search"
        )

    def check_search_results(search_results: Optional[AwesomeNet]) -> CompileStep[AwesomeNet]:
        if search_results is None:
            raise sima_logger.UserFacingException(
                "Mixed precision quantization could not reach the target accuracy."
            )
        int16_count, node_count = ranking.count_int16_nodes(search_results)
        print(f"Completed mixed-precision search with {int16_count} of {node_count} nodes using 16 bits")
        return CompileStep.pure(search_results)

    def do_passes(fp_net: AwesomeNet,
                  calibration_dataset: Iterable[InputDict],
                  analysis_dataset: Iterable[InputDict],
                  evaluation_dataset: Iterable[tuple[InputValues, GroundTruth]]):
        # Quantize with int8. Analyze quantization noise. Make a sensitivity ranking based on noise.
        # Do a binary search on how many nodes to quantize in int16.
        # Any step that uses fp_net will make a copy first so that fp_net is unchanged.
        return CompileStep.from_thunk(lambda: copy.deepcopy(fp_net)) \
            .then(lambda net: quantization_pass(net, calibration_dataset)) \
            .then(lambda net: noise_analysis()(fp_net, net, analysis_dataset)) \
            .map(get_ranking) \
            .then(lambda node_ranking: do_binary_search(fp_net, calibration_dataset,
                                                        evaluation_dataset, node_ranking)) \
            .then(check_search_results)

    return do_passes

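# Usage sketch (hypothetical; the dataset names are illustrative placeholders
# for the four iterables the returned pass expects):
#
#     mp_pass = noise_based_mixed_precision_quantization(
#         opt_configs, accuracy_criterion, target_accuracy=0.75)
#     step = mp_pass(fp_net, calibration_samples, analysis_samples,
#                    labeled_evaluation_samples)
#     # Running `step` yields a mixed int8/int16 AwesomeNet, or raises
#     # UserFacingException if the target accuracy cannot be reached.
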
def dump_diagnostic_files(model_config: ModelConfigs, opt_config: OptimizationConfigs, *,
                          prefix: str = "", suffix: str = "") \
        -> Callable[[AwesomeNet], CompileStep[None]]:
    """
    Save intermediate compilation results to files for diagnostic purposes.

    The following files are saved in model_config.output_directory. The model is saved to
    {prefix}{model_config.name}{suffix}.yaml and {prefix}{model_config.name}{suffix}.npz.
    The configuration parameters are saved to {model_config.model_name}.yaml.

    :param model_config: Model testing configuration.
    :param opt_config: Optimization parameters.
    :param prefix: Prefix to attach to filenames.
    :param suffix: Suffix to attach to filenames.
    :return: A compiler pass that dumps diagnostic files.
    """
    def run_dump(net: AwesomeNet):
        os.makedirs(model_config.output_directory, exist_ok=True)
        dump_yaml_npz(model_config, net, name_prefix=prefix, name_postfix=suffix)
        dump_configs_to_yaml(model_config, opt_config)
        return CompileStep.pure(None)

    return run_dump

def dump_diagnostic_files_after(step: CompileStep[AwesomeNet],
                                model_config: ModelConfigs,
                                opt_config: OptimizationConfigs, *,
                                condition: bool = True,
                                prefix: str = "",
                                suffix: str = "") \
        -> CompileStep[AwesomeNet]:
    """
    Run a compiler pass, then save intermediate compilation results to files for diagnostic
    purposes. See dump_diagnostic_files for details.

    :param step: Compilation step to run before dumping results.
    :param model_config: Model testing configuration.
    :param opt_config: Optimization parameters.
    :param condition: Whether to dump files. If False, only the compilation step runs.
    :param prefix: Prefix to attach to filenames.
    :param suffix: Suffix to attach to filenames.
    :return: Compilation step extended to dump files at the end.
    """
    if not condition:
        return step

    dump = dump_diagnostic_files(model_config, opt_config, prefix=prefix, suffix=suffix)
    return (step
            .then(lambda net: dump(net)
                  .then(lambda _: CompileStep.pure(net))))

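# End-to-end pipeline sketch (hypothetical composition; `params`, `opt_configs`,
# `model_configs`, and `samples` are illustrative names, and the sketch assumes
# import_and_transform yields a network the quantization passes accept, as the
# loader helper's name suggests; only combinators that appear in this module --
# pure, map, then, from_thunk -- are used):
#
#     quantized_step = import_and_transform(params, name="my_model") \
#         .then(lambda net: calibration_quantization(opt_configs)(net, samples))
#     pipeline = dump_diagnostic_files_after(quantized_step, model_configs, opt_configs)
#     # Driver code runs `pipeline` to obtain the quantized network and the
#     # diagnostic files described above.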