#########################################################
# Copyright (C) 2022 SiMa Technologies, Inc.
#
# This material is SiMa proprietary and confidential.
#
# This material may not be copied or distributed without
# the express prior written permission of SiMa.
#
# All rights reserved.
#########################################################
# Code owner: Christopher Rodrigues
#########################################################
"""
Compiler passes.
These functions wrap IR transforms into compiler passes to be called
by driver or API functions. Driver code should construct a compilation
pipeline in CompileStep, then run it.
"""
import copy
import functools
import itertools
import dataclasses
import math
from typing import Callable, Iterable, Dict, Optional, Any, List, Tuple, TypeVar, Generic
import os
import numpy as np
from sima_utils.logging import sima_logger
from afe._tvm._defines import TVMIRModule
from afe.apis._sanitize_errors import sanitize_tvm_error
from afe.apis.compilation_job_base import GroundTruth
from afe.apis.defines import InputValues
from afe.backends import Backend
from afe.core.configs import CalibrationConfigs, QuantizationConfigs, OptimizationConfigs, ModelConfigs, \
QuantizationPrecision
from afe.core.mixed_precision import ranking
from afe.core.utils import dump_yaml_npz, dump_configs_to_yaml, LengthHintedIterable
from afe.core.graph_analyzer import graph_analyzer
import afe.core.graph_analyzer.utils as graph_analyzer_utils
from afe.driver.statistic import Statistic
from afe.driver.compile_step import CompileStep
from afe.ir.execute import create_node_quant_executor
from afe.ir.net import AwesomeNet, NodeName
from afe.ir.defines import BiasCorrectionType
from afe.ir.transform import Calibrate, UpdateQuantizationConfigs, Quantize, InsertNodeObservers, RemoveNodeObservers
from afe.ir.transform.calibration_transforms import InsertPerChannelNodeObservers, CalibratePerChannel
from afe.ir.transform.requantization_fusion import FuseRequantizations
from afe.ir.transform.channel_scaling import get_pairings, pairings_update_pass
from afe.ir.transform.requantization_hoisting.requantization_hoisting import HoistRequantization
from afe.load.importers.general_importer import ImporterParams, import_from_import_params, default_layout, \
update_with_detected_format
# Type variable for the caller-specific value carried by the generic helpers in this
# module (the evaluation result type and the payload returned by the binary search).
_A = TypeVar("_A")
# Maximum number of binary search steps for automatic mixed precision quantization
# NOTE(review): the bare `[docs]` line below looks like a documentation-export artifact
# rather than source code -- confirm and remove.
[docs]
MIXED_PRECISION_SEARCH_LIMIT = 20
[docs]
def import_model(config: ImporterParams) \
        -> CompileStep[Tuple[TVMIRModule, Optional[List[str]]]]:
    """
    Import a model and wrap the result in a compiler step.

    The import is performed when this function is called; the returned step wraps the
    already-imported result using CompileStep.pure.

    :param config: Configuration for import of a model. It is passed through
        update_with_detected_format before the import is attempted.
    :return: A compiler step holding the imported TVM module and the module's output names.
        Output names are only included if the source model has output names.
    """
    detected_config = update_with_detected_format(config)
    try:
        module, output_names = import_from_import_params(detected_config)
    except Exception as err:
        # NOTE(review): this relies on sanitize_tvm_error raising. If it ever returned
        # normally, the names unpacked above would be unbound below -- confirm.
        sanitize_tvm_error("Error occurs in importing model.", err)
    imported = (module, output_names)
    return CompileStep.pure(imported)
[docs]
def update_quantization_configs(quantization_config: QuantizationConfigs, *,
                                custom_quantization_configs: Optional[Dict[NodeName, Dict[str, Any]]] = None) \
        -> Callable[[AwesomeNet], CompileStep[AwesomeNet]]:
    """
    Build a compiler pass that stores quantization parameters on the nodes of a network.

    :param quantization_config: Global configuration parameters for quantization. These will be
        inserted on all nodes, but will not override previously inserted parameters.
    :param custom_quantization_configs: Per-node overrides, for use in tests only.
        custom_quantization_configs[node_name][field_name] = value sets the named
        QuantizationConfigs field of the named node. For example,
        {"MLA_1/conv2d_add_84": {"output_int32": True}} sets the output_int32 field of node
        "MLA_1/conv2d_add_84" to True.
    :return: Compiler pass to update parameters. The pass mutates and returns its input.
    """
    def apply_configs(net: AwesomeNet) -> CompileStep[AwesomeNet]:
        # The transform writes the global and per-node settings onto the graph nodes in place;
        # the quantization pass reads them later.
        transform = UpdateQuantizationConfigs(quantization_config, custom_quantization_configs)
        transform(net)
        return CompileStep.pure(net)

    return apply_configs
[docs]
def calibration(calibration_config: CalibrationConfigs) \
        -> Callable[[AwesomeNet, Iterable[InputDict]], CompileStep[AwesomeNet]]:
    """
    Create a compiler pass to calibrate a network.

    :param calibration_config: Configuration for calibration. Its calibration_method selects
        the node observers, and its num_calibration_samples, when not None, caps the number
        of dataset items that are used.
    :return: A compiler pass to calibrate a network. The pass mutates and returns its input.
    """
    def do_calibrate(net: AwesomeNet, inputs: Iterable[InputDict]) \
            -> CompileStep[AwesomeNet]:
        # Insert observers calculating activation statistics.
        InsertNodeObservers(calibration_method=calibration_config.calibration_method)(net)

        # Number of samples the dataset itself provides, when that can be determined
        # without consuming it.
        if isinstance(inputs, LengthHintedIterable):
            available = inputs.get_length()
        elif isinstance(inputs, list):
            available = len(inputs)
        else:
            available = None

        sample_limit = calibration_config.num_calibration_samples
        if sample_limit is None:
            length_hint = available
        else:
            # Restrict the number of inputs used for calibration
            inputs = itertools.islice(inputs, sample_limit)
            # The dataset may hold fewer samples than the limit; do not overstate the hint.
            length_hint = sample_limit if available is None else min(sample_limit, available)

        calibrator = Calibrate(length_hint=length_hint, dataset=inputs)
        calibrator(net)  # Mutates net
        return CompileStep.pure(net)

    return do_calibrate
[docs]
def quantization(input_dataset: Optional[Iterable[InputDict]]) -> Callable[[AwesomeNet], CompileStep[AwesomeNet]]:
    """
    Create a compiler pass to quantize a network. Quantization configuration is set by the
    UpdateQuantizationConfigs transform and the calibration pass.

    :param input_dataset: Dataset forwarded to the Quantize transform; may be None.
    :return: A compiler pass to quantize a network. The pass mutates and returns its input.
    """
    # Clean-up transforms applied, in this order, after the Quantize transform has run:
    #   RemoveNodeObservers  - remove NodeObservers in order to remove PyTorch
    #                          infrastructure from AwesomeNet
    #   HoistRequantization  - move Requantize nodes as early as possible
    #   FuseRequantizations  - simplify Requantize nodes that were inserted by quantization
    cleanup_transforms = (RemoveNodeObservers, HoistRequantization, FuseRequantizations)

    def run_quantization(net: AwesomeNet) -> CompileStep[AwesomeNet]:
        # Every transform below mutates net in place.
        Quantize()(net, input_dataset)
        for transform_cls in cleanup_transforms:
            transform_cls()(net)
        return CompileStep.pure(net)

    return run_quantization
[docs]
def equalization(calibration_config: CalibrationConfigs, quantization_config: QuantizationConfigs) \
        -> Callable[[AwesomeNet, Iterable[InputDict]], CompileStep[AwesomeNet]]:
    """
    Run SmoothQuant and/or channel equalization if they are enabled in config.
    This pass should run before quantization.

    :param calibration_config: Configuration for calibration. Its num_calibration_samples,
        when not None, caps the number of dataset items that are used.
    :param quantization_config: Quantization parameters. The channel_equalization and
        smooth_quant settings decide whether the pass does anything.
    :return: A compiler pass taking a network and a dataset. The pass mutates and returns
        the network. If neither optimization is enabled, the pass returns the network unchanged.
    """
    # If none of these optimizations will run, return a pass that does nothing
    if not quantization_config.channel_equalization.get() and not quantization_config.smooth_quant.get():
        return lambda net, input_dataset: CompileStep.pure(net)

    def do_passes(net: AwesomeNet, input_dataset: Iterable[InputDict]) -> CompileStep[AwesomeNet]:
        InsertPerChannelNodeObservers()(net)

        # Number of samples the dataset itself provides, when that can be determined
        # without consuming it. Lists are handled the same way as in the calibration pass.
        if isinstance(input_dataset, LengthHintedIterable):
            available = input_dataset.get_length()
        elif isinstance(input_dataset, list):
            available = len(input_dataset)
        else:
            available = None

        sample_limit = calibration_config.num_calibration_samples
        if sample_limit is None:
            length_hint = available
        else:
            # Restrict the number of inputs used for calibration
            input_dataset = itertools.islice(input_dataset, sample_limit)
            # The dataset may hold fewer samples than the limit; do not overstate the hint.
            length_hint = sample_limit if available is None else min(sample_limit, available)

        CalibratePerChannel(length_hint, input_dataset)(net)

        # Run pass to find node-pairs to scale, and a second pass to compute scales and update weights
        # using channel equalization
        pairings_lists = get_pairings(net)
        pairings_update_pass(pairings_lists)

        RemoveNodeObservers()(net)
        return CompileStep.pure(net)

    return do_passes
[docs]
def calibration_quantization(config: OptimizationConfigs, *,
                             custom_quantization_configs: Optional[Dict[NodeName, Dict[str, Any]]] = None) \
        -> Callable[[AwesomeNet, Iterable[InputDict]], CompileStep[AwesomeNet]]:
    """
    Run calibration and quantization. Quantization-related optimizations that can run at the
    same time are included here.

    The pass chains, in order: update_quantization_configs, equalization, calibration and
    quantization. The dataset is handed to equalization and calibration, and to quantization
    only when the configured bias correction type is ITERATIVE.
    NOTE(review): because several stages receive the same dataset object, it presumably has
    to support being iterated more than once -- confirm against callers.

    :param config: Parameters for calibration and quantization.
    :param custom_quantization_configs: Dictionary to override quantization settings for
        specific nodes. This parameter may only be used in tests.
    :return: A compiler pass to calibrate and quantize a network. The pass mutates and
        returns its input.
    """
    def run_pipeline(net: AwesomeNet, input_dataset: Iterable[InputDict]) \
            -> CompileStep[AwesomeNet]:
        annotate = update_quantization_configs(config.quantization_configs,
                                               custom_quantization_configs=custom_quantization_configs)
        configured = annotate(net)
        equalized = configured.then(
            lambda net1: equalization(config.calibration_configs, config.quantization_configs)(net1, input_dataset)
        )
        calibrated = equalized.then(
            lambda net2: calibration(config.calibration_configs)(net2, input_dataset)
        )

        # Only iterative bias correction gets the dataset during quantization.
        iterative_bias_correction = \
            config.quantization_configs.biascorr_type.get() == BiasCorrectionType.ITERATIVE
        quantize = quantization(input_dataset if iterative_bias_correction else None)
        return calibrated.then(quantize)

    return run_pipeline
[docs]
def evaluation(
        criterion: Statistic[Tuple[List[np.ndarray], GroundTruth], _A], *,
        fast_mode: bool = False
) -> Callable[[AwesomeNet, Iterable[Tuple[InputValues, GroundTruth]]], CompileStep[_A]]:
    """
    Execute a model on an input set and compute an aggregate result from the model's outputs.

    The primary use case of this function is to estimate a model's accuracy. In this case,
    criterion computes an accuracy metric for the model over the data set.

    :param criterion: Statistic computed over (model output, ground truth) pairs. It is driven
        through its initialize, update and finish methods.
    :param fast_mode: Whether to execute in fast mode.
    :return: A compiler pass that takes an AwesomeNet and data source, runs the model on
        every item, and returns the statistic's final result.
    """
    def run_evaluation(net: AwesomeNet, evaluation_data: Iterable[Tuple[InputValues, GroundTruth]]) \
            -> CompileStep[_A]:
        executor = create_node_quant_executor(fast_mode=fast_mode)
        running = criterion.initialize()
        for sample_inputs, ground_truth in evaluation_data:
            outputs = net.run(inputs=sample_inputs, node_callable=executor)
            # NOTE(review): the return value of update is discarded, so the statistic is
            # assumed to update the accumulator in place -- confirm against Statistic.
            criterion.update(running, (outputs, ground_truth))
        return CompileStep.pure(criterion.finish(running))

    return run_evaluation
@dataclasses.dataclass
class BinarySearchState(Generic[_A]):
    """
    State of binary search.

    The fields are declared in the order in which binary_search passes them positionally.

    :param lo: Low bound of search range. This is the highest index that was found to not satisfy the search condition.
    :param hi: High bound of search range. This is the lowest index that was found to satisfy the search condition.
    :param hi_value: Value associated with high bound. This will be returned if the high value is selected when the
        search finishes.
    :param iteration: Iteration of search, starting from zero. Used for deciding to stop early.
    """
    lo: int
    hi: int
    hi_value: _A
    iteration: int
[docs]
def binary_search(get_result: Callable[[int], CompileStep[tuple[bool, _A]]], lo: int, hi: int, limit: int, *,
                  procedure_name: Optional[str] = None) \
        -> CompileStep[Optional[_A]]:
    """
    Do a binary search for the smallest integer n in the range [lo, hi] such that get_result(n)
    returns True in its bool result. Return the second result of get_result for the best n that was found,
    or None if every call returned False. The search always tests lo and hi, so if one of these
    returns True then the search will find a satisfactory n.

    The search assumes that get_result is monotonic, that is, there's some n such that
    get_result(i) returns False for all i <= n and returns True for all i > n. If it is
    not monotonic, it may not find the optimal n.

    A single-point range (lo == hi) is accepted: that point is evaluated once and its value is
    returned if it satisfies the condition, otherwise None.

    :param get_result: How to evaluate the search at a given value of n. When it runs, it returns
        a success flag and caller-specific data.
    :param lo: Lowest value of n to evaluate
    :param hi: Highest value of n to evaluate. Must not be less than lo.
    :param limit: Maximum number of binary search steps to perform, not counting the
        evaluations of lo and hi.
    :param procedure_name: Name of the procedure to be printed in progress messages to the console.
        If None, do not print progress messages.
    :return: A compiler step producing the data returned by get_result for the best n found,
        or None if no tested n satisfied the condition.
    """
    assert hi >= lo, "binary_search requires lo <= hi"

    if hi == lo:
        # Degenerate range: there is exactly one candidate and nothing to bisect. Handling it
        # here also keeps math.log2 below from being called with zero.
        if procedure_name is not None:
            print(f"Evaluating {procedure_name} at minimum and maximum values")
        return get_result(lo).map(lambda result: result[1] if result[0] else None)

    max_iterations = min(math.ceil(math.log2(hi - lo)), limit)  # Upper bound on number of iterations

    def update_state(state: BinarySearchState, mid: int, result: tuple[bool, _A]) -> BinarySearchState:
        # Update the search state with result of one evaluation. Evaluation of 'mid' has returned 'result'.
        good, value = result
        if good:
            # Replace hi with mid
            return BinarySearchState(state.lo, mid, value, state.iteration + 1)
        else:
            # Replace lo with mid
            return BinarySearchState(mid, state.hi, state.hi_value, state.iteration + 1)

    def continue_search(state: BinarySearchState) -> CompileStep[Optional[_A]]:
        # Main loop of the search.
        # If search has completed or if search stops early, return the best value that was found
        if state.hi - state.lo == 1 or state.iteration >= limit:
            return CompileStep.pure(state.hi_value)

        if procedure_name is not None:
            print(f"Evaluating {procedure_name} step {state.iteration + 1} out of at most {max_iterations}")

        mid = state.lo + (state.hi - state.lo) // 2
        return get_result(mid).then(lambda result: continue_search(update_state(state, mid, result)))

    def start_search(lo_result: tuple[bool, _A], hi_result: tuple[bool, _A]) -> CompileStep[Optional[_A]]:
        # Check search results from endpoints of the range. If endpoints are as expected,
        # begin the main loop.
        lo_good, lo_value = lo_result
        hi_good, hi_value = hi_result

        if lo_good:
            # Lowest possible value satisfies the search condition. Stop here.
            return CompileStep.pure(lo_value)

        if not hi_good:
            # Highest possible value does not satisfy the search condition. Search fails.
            return CompileStep.pure(None)

        return continue_search(BinarySearchState(lo, hi, hi_value, 0))

    # Start the process. Evaluate the function at lo and hi, then proceed to start_search.
    if procedure_name is not None:
        print(f"Evaluating {procedure_name} at minimum and maximum values")

    return get_result(lo).then(lambda l: get_result(hi).then(lambda h: start_search(l, h)))
def _annotate_quantization_for_mixed_precision(
        net: AwesomeNet,
        node_ranking: list[list[NodeName]],
        cutoff: int,
) -> CompileStep[AwesomeNet]:
    """
    Annotate a network with one set of parameters for mixed precision search.
    High precision is used for nodes in node_ranking[:cutoff].

    :param net: Net to quantize. This function does not modify it; the annotation is applied
        to a deep copy.
    :param node_ranking: Ranking of nodes, ordered from high to low sensitivity.
    :param cutoff: Index of first low-precision node group in node_ranking. Indices before
        this use high precision.
    :return: Compiler step producing a new network with high-precision nodes annotated.
    """
    # Flatten the groups that come before the cutoff into one set of node names.
    sensitive_groups = node_ranking[:cutoff]
    high_precision_nodes = set(itertools.chain.from_iterable(sensitive_groups))

    def make_annotated_copy():
        # Copy first so that the caller's net is left untouched.
        annotated = copy.deepcopy(net)
        ranking.annotate_int16_nodes(annotated, high_precision_nodes)
        return annotated

    return CompileStep.from_thunk(make_annotated_copy)
def _single_mixed_precision_iteration(
        net: AwesomeNet,
        node_ranking: list[list[NodeName]],
        threshold: float,
        quantization_pass: Callable[[AwesomeNet], CompileStep[AwesomeNet]],
        evaluation_pass: Callable[[AwesomeNet], CompileStep[float]],
) -> Callable[[int], CompileStep[tuple[bool, AwesomeNet]]]:
    """
    Build the per-cutoff evaluation function used by the mixed precision binary search.

    For a given cutoff, the function annotates a copy of the network, quantizes it,
    evaluates its accuracy, logs the accuracy, and compares it against the threshold.

    :param net: Network to quantize. This function does not modify it.
    :param node_ranking: Ranking of nodes from high to low sensitivity. Nodes having
        equal sensitivity reside together in the inner list.
    :param threshold: Minimum acceptable quantized accuracy
    :param quantization_pass: Compiler pass that quantizes the network
    :param evaluation_pass: Compiler pass that evaluates accuracy of the network
    :return: Iteration function for binary search. Takes a cutoff value and
        returns a tuple of (whether the quantized accuracy is good, the quantized network).
    """
    def run_iteration(cutoff: int):
        def evaluate(qnet: AwesomeNet):
            def package(accuracy: float) -> tuple[bool, AwesomeNet]:
                # Report the measured accuracy, then pair the pass/fail flag with the network.
                sima_logger.sima_log_info(
                    f"Model quantized with {cutoff} sensitivity classes in 16-bit precision has accuracy {accuracy}"
                )
                return accuracy >= threshold, qnet

            return evaluation_pass(qnet).map(package)

        annotated = _annotate_quantization_for_mixed_precision(net, node_ranking, cutoff)
        return annotated.then(quantization_pass).then(evaluate)

    return run_iteration
[docs]
def noise_analysis() \
        -> Callable[[AwesomeNet, AwesomeNet, Iterable[InputDict]],
                    CompileStep[Dict[graph_analyzer_utils.Metric, graph_analyzer.AnalyzedResultDict]]]:
    """
    Analyze noise that is introduced by quantization.

    The analysis uses a QuantizedGraphAnalyzer in local_feed mode with the psnr metric.

    :return: A compiler pass to analyze noise. The pass takes as parameters an un-quantized net,
        a quantized net derived from it, and the evaluation input. It hands both nets and the
        input to the analyzer and returns the analyzer's results.
    """
    feed_mode = graph_analyzer_utils.QuantizedGraphAnalyzerMode.local_feed

    def run_analysis(fp_net: AwesomeNet, quant_net: AwesomeNet, input_dataset: Iterable[InputDict]) \
            -> CompileStep[Dict[graph_analyzer_utils.Metric, graph_analyzer.AnalyzedResultDict]]:
        analyzer = graph_analyzer.QuantizedGraphAnalyzer(mode=feed_mode)
        metric = graph_analyzer_utils.Metric.psnr
        analyzer.analyze(fp_net, quant_net, input_dataset, metric)
        # The analyzer keeps its findings on the instance rather than returning them.
        return CompileStep.pure(analyzer.analyzed_results)

    return run_analysis
[docs]
def noise_based_mixed_precision_quantization(
        config: OptimizationConfigs,
        criterion: Statistic[Tuple[List[np.ndarray], GroundTruth], float], *,
        target_accuracy: float,
        custom_quantization_configs: Optional[Dict[NodeName, Dict[str, Any]]] = None,
        max_iterations: int = MIXED_PRECISION_SEARCH_LIMIT,
        fast_mode: bool = True
) -> Callable[[AwesomeNet, Iterable[InputDict], Iterable[InputDict], Iterable[tuple[InputValues, GroundTruth]]],
              CompileStep[AwesomeNet]]:
    """
    Do mixed-precision quantization using noise analysis to choose precision.

    It first quantizes the model with int8 precision and measures its quantization noise on the
    analysis dataset. It then searches for the smallest number of sensitivity classes that must
    use int16 for the quantized network to reach the target accuracy.
    It raises an exception if it cannot achieve the target accuracy.
    A quantized network achieves the target accuracy if evaluating it using 'criterion' returns
    a number that is at least 'target_accuracy'.

    NOTE(review): the calibration dataset is passed to the quantization pass once for the int8
    run and again for each search step, so it presumably must support repeated iteration -- confirm.

    :param config: Parameters for calibration and quantization. The quantization precision
        must be int8.
    :param criterion: Method of evaluating accuracy on a data set.
    :param target_accuracy: Desired accuracy of network.
    :param custom_quantization_configs: Dictionary to override quantization settings for specific nodes.
        This parameter may only be used in tests.
    :param max_iterations: Maximum number of binary search steps to perform.
    :param fast_mode: Whether to use fast mode when executing the network.
    :return: A compiler pass that does mixed-precision quantization on a floating-point model,
        calibration dataset, analysis dataset, and evaluation dataset.
    """
    configured_precision = config.quantization_configs.quantization_precision.get()
    assert configured_precision == QuantizationPrecision.INT_8, \
        "For mixed-precision quantization, the optimization configuration must use int8 precision"

    quantization_pass = calibration_quantization(config, custom_quantization_configs=custom_quantization_configs)
    evaluation_pass = evaluation(criterion, fast_mode=fast_mode)

    def rank_nodes(stats: Dict[graph_analyzer_utils.Metric, graph_analyzer.AnalyzedResultDict]) \
            -> list[list[NodeName]]:
        # Make a sensitivity ranking from the stats returned by noise analysis.
        assert len(stats) == 1, "Analysis must use a single metric"
        try:
            psnr_by_node = stats['psnr']
        except KeyError:
            raise ValueError("Expected analysis results from psnr metric")

        # Sensitivity estimate per node: the reciprocal of the mean statistic, so a lower
        # PSNR gives a higher sensitivity. The small constant avoids dividing by zero.
        sensitivity = {name: 1/(np.mean(values) + 1e-10) for name, values in psnr_by_node.items()}
        node_ranking = ranking.rank_int16_nodes(sensitivity)
        sima_logger.sima_log_info(
            f"Model graph nodes were put into {len(node_ranking)} sensitivity classes for mixed precision quantization"
        )
        return node_ranking

    def search_cutoff(fp_net: AwesomeNet,
                      calibration_dataset: Iterable[InputDict],
                      evaluation_dataset: Iterable[tuple[InputValues, GroundTruth]],
                      node_ranking: list[list[NodeName]]) -> CompileStep[Optional[AwesomeNet]]:
        # Binary search over how many leading sensitivity classes use int16.
        evaluate_cutoff = _single_mixed_precision_iteration(
            fp_net, node_ranking, target_accuracy,
            lambda net: quantization_pass(net, calibration_dataset),
            lambda net: evaluation_pass(net, evaluation_dataset)
        )
        return binary_search(
            evaluate_cutoff,
            0, len(node_ranking), max_iterations,
            procedure_name="mixed-precision search"
        )

    def report_or_fail(search_results: Optional[AwesomeNet]) -> CompileStep[AwesomeNet]:
        # A missing result means no tested configuration reached the target accuracy.
        if search_results is None:
            raise sima_logger.UserFacingException(
                "Mixed precision quantization could not reach the target accuracy."
            )
        int16_count, node_count = ranking.count_int16_nodes(search_results)
        print(f"Completed mixed-precision search with {int16_count} of {node_count} nodes using 16 bits")
        return CompileStep.pure(search_results)

    def do_passes(fp_net: AwesomeNet,
                  calibration_dataset: Iterable[InputDict],
                  analysis_dataset: Iterable[InputDict],
                  evaluation_dataset: Iterable[tuple[InputValues, GroundTruth]]):
        # Quantize with int8. Analyze quantization noise. Make sensitivity ranking based on noise.
        # Do a binary search on how many nodes to quantize in int16.
        # Any step that uses fp_net will make a copy first so that fp_net is unchanged.
        def copy_fp_net():
            return copy.deepcopy(fp_net)

        def quantize_int8(net: AwesomeNet):
            return quantization_pass(net, calibration_dataset)

        def analyze_noise(quantized_net: AwesomeNet):
            return noise_analysis()(fp_net, quantized_net, analysis_dataset)

        def search(node_ranking: list[list[NodeName]]):
            return search_cutoff(fp_net, calibration_dataset, evaluation_dataset, node_ranking)

        pipeline = CompileStep.from_thunk(copy_fp_net)
        pipeline = pipeline.then(quantize_int8)
        pipeline = pipeline.then(analyze_noise)
        pipeline = pipeline.map(rank_nodes)
        pipeline = pipeline.then(search)
        return pipeline.then(report_or_fail)

    return do_passes
[docs]
def dump_diagnostic_files(model_config: ModelConfigs,
                          opt_config: OptimizationConfigs,
                          *, prefix: str = "", suffix: str = "") \
        -> Callable[[AwesomeNet], CompileStep[None]]:
    """
    Save intermediate compilation results to files for diagnostic purposes.

    The pass creates model_config.output_directory if needed, then writes the model with
    dump_yaml_npz and the configuration parameters with dump_configs_to_yaml.
    NOTE(review): the exact file names are decided by those helpers; presumably the model goes
    to {prefix}{model name}{suffix} .yaml/.npz files in the output directory -- confirm there.

    :param model_config: Model testing configuration
    :param opt_config: Optimization parameters
    :param prefix: Prefix to attach to filenames
    :param suffix: Suffix to attach to filenames
    :return: A compiler pass that dumps diagnostic files and produces None
    """
    def write_files(net: AwesomeNet):
        output_dir = model_config.output_directory
        os.makedirs(output_dir, exist_ok=True)
        # Model first, then the configuration it was compiled with.
        dump_yaml_npz(model_config, net, name_prefix=prefix, name_postfix=suffix)
        dump_configs_to_yaml(model_config, opt_config)
        return CompileStep.pure(None)

    return write_files
[docs]
def dump_diagnostic_files_after(step: CompileStep[AwesomeNet],
                                model_config: ModelConfigs,
                                opt_config: OptimizationConfigs,
                                *, condition: bool = True,
                                prefix: str = "", suffix: str = "") \
        -> CompileStep[AwesomeNet]:
    """
    Run a compiler pass, then save intermediate compilation
    results to files for diagnostic purposes. See dump_diagnostic_files for
    details.

    :param step: Compilation step to run before dumping results
    :param model_config: Model testing configuration
    :param opt_config: Optimization parameters
    :param condition: Whether to dump files. If False, the given step is returned unchanged.
    :param prefix: Prefix to attach to filenames
    :param suffix: Suffix to attach to filenames
    :return: Compilation step extended to dump files at the end. It still produces the
        network produced by step.
    """
    if not condition:
        return step

    dump = dump_diagnostic_files(model_config, opt_config, prefix=prefix, suffix=suffix)

    def dump_and_keep(net: AwesomeNet) -> CompileStep[AwesomeNet]:
        # The dump pass produces None, so put the network back as the step's result.
        return dump(net).then(lambda _: CompileStep.pure(net))

    return step.then(dump_and_keep)