Source code for afe.ir.transform.calibration_transforms

#########################################################
# Copyright (C) 2020 SiMa Technologies, Inc.
#
# This material is SiMa proprietary and confidential.
#
# This material may not be copied or distributed without
# the express prior written permission of SiMa.
#
# All rights reserved.
#########################################################
# Code owner: Joey Chou
#########################################################
from typing import Dict, Optional, Iterable, Iterator, Callable

import numpy as np
from termcolor import colored

from afe.apis.defines import CalibrationMethod, MinMaxMethod
from afe.core.configs import RunConfigs
from afe.ir.defines import Status, InputName, NodeName, LogNodeReporter
from afe.ir.net import AwesomeNet, update_awesomenet_status
from afe.ir.node import AwesomeNode, node_is_awesomenet, node_uses_observer, node_is_sima_ir
from afe.ir.node_observer import NodeObserver
import afe.ir.operations as afe_op
from afe.ir.attributes import ConvAddActivationAttrs
from afe.ir.sima_ir import SiMaIRTensorTypes, SiMaIR
from afe.ir.transform.base_transform import BaseTransform
from sima_utils.common import print_progressbar
from sima_utils.data.data_generator import DataGenerator
from sima_utils.logging import sima_logger
import time

# Yellow "DONE" suffix printed after a calibration pass's "Running ..." status line.
DONE = colored("DONE", color="yellow")
# Shared run configuration (fast_mode enabled) passed to the node calibration callbacks.
_EXECUTOR = RunConfigs(fast_mode=True)
def calibrate_node(node: AwesomeNode, inputs: Dict[InputName, SiMaIRTensorTypes],
                   node_outputs: Dict[NodeName, SiMaIRTensorTypes]):
    """
    Calibrate a single AwesomeNode.

    The node's SiMaIR is executed through ``SiMaIR.calibrate`` with the module-level
    executor configuration. The node is then marked as ``Status.CALIBRATED`` and its
    outputs are stored in ``node_outputs`` under the node's name.

    :param node: AwesomeNode. The node to be calibrated. Its ``ir`` must be a SiMaIR.
    :param inputs: Dict[InputName, SiMaIRTensorTypes]. The node's input values.
    :param node_outputs: Dict[NodeName, SiMaIRTensorTypes]. Output dictionary, updated in place.
    """
    sima_ir = node.ir
    assert isinstance(sima_ir, SiMaIR)
    calibrated_outputs = sima_ir.calibrate(inputs, _EXECUTOR)
    node.status = Status.CALIBRATED
    node_outputs[node.name] = calibrated_outputs
def per_channel_calibrate_node(node: AwesomeNode, inputs: Dict[InputName, SiMaIRTensorTypes],
                               node_outputs: Dict[NodeName, SiMaIRTensorTypes]):
    """
    Execute an AwesomeNode and feed its output to the node's observer, if any.

    This pass serves SmoothQuant and channel equalization. It runs the node with
    ``SiMaIR.run`` rather than ``SiMaIR.calibrate``, so it does not collect everything
    that quantization needs.

    :param node: AwesomeNode. The node to be analyzed. Its ``ir`` must be a SiMaIR.
    :param inputs: Dict[InputName, SiMaIRTensorTypes]. The node's input values.
    :param node_outputs: Dict[NodeName, SiMaIRTensorTypes]. Output dictionary, updated in place.
    """
    sima_ir = node.ir
    assert isinstance(sima_ir, SiMaIR)
    result = sima_ir.run(inputs, _EXECUTOR)

    channel_observer = sima_ir.calib_attrs.observer
    if channel_observer is not None:
        # Observers are only attached to single-output nodes, so the result is one array.
        assert isinstance(result, np.ndarray)
        channel_observer.update(result.astype(np.float32, copy=False))

    node_outputs[node.name] = result
def _update_fp_input_range(net: AwesomeNet, input: Dict):
    """
    Store the floating-point value range of each network input on ``net``.

    ``net.fp_input_range`` is replaced by a new dict that maps every key of ``input``
    to ``[min, max]`` of the corresponding array, both converted to Python floats.
    Values stored by earlier calls are discarded.

    :param net: AwesomeNet whose ``fp_input_range`` attribute is overwritten.
    :param input: Dict mapping input names to numpy arrays.
    """
    def _value_range(array):
        assert isinstance(array, np.ndarray)
        return [float(array.min()), float(array.max())]

    net.fp_input_range = {name: _value_range(array) for name, array in input.items()}
class InsertNodeObservers(BaseTransform):
    """
    Inserts NodeObservers into AwesomeNet's nodes.

    NodeObservers will not be added to nodes that don't use information obtained from the
    calibration pass. Those nodes instead derive their quantization parameters from the
    quantization parameters of their inputs. Nested AwesomeNets are processed recursively.

    For each node that uses an observer, this transform sets:

    * ``calib_attrs.observer``: a NodeObserver built from the configured calibration
      method, the node's asymmetry setting and its precomputed quantization info.
    * ``calib_attrs.intermediate_observers``: one min-max NodeObserver per name in
      ``node.ir.operation.intermediate_names``. For a node whose attrs are
      ConvAddActivationAttrs with bias, this is instead a single ``'mean'`` observer
      with mean estimation enabled.

    :param calibration_method: CalibrationMethod used in calibration. Determines the type of
        observers which shall be created. See CalibrationMethod Enum class for supported values.
    """
    def __init__(self, calibration_method: CalibrationMethod):
        self._calibration_method = calibration_method

    def __call__(self, net: AwesomeNet) -> None:
        """
        Attach observers to every eligible node of ``net``, in place.

        :param net: AwesomeNet to instrument. Nested AwesomeNet nodes are visited recursively.
        """
        for node in net.nodes.values():
            if node_is_awesomenet(node):
                self(node.ir)
            else:
                # Don't add observer to subgraph node
                if node_uses_observer(node):
                    node.ir.calib_attrs.observer = NodeObserver(calibration_method=self._calibration_method,
                                                                asymmetry=node.ir.quant_config.asymmetry.get(),
                                                                node_reporter=LogNodeReporter(node.name),
                                                                qdq_quantization=node.ir.calib_attrs.precomputed_quant)
                    # Create a dict of intermediate observers. Always use min_max calibration method.
                    node.ir.calib_attrs.intermediate_observers = {
                        name: NodeObserver(calibration_method=MinMaxMethod(), asymmetry=True,
                                           node_reporter=LogNodeReporter(node.name))
                        for name in node.ir.operation.intermediate_names
                    }
                    node_attrs = node.ir.attrs
                    if isinstance(node_attrs, ConvAddActivationAttrs) and node_attrs.bias_attrs is not None:
                        # NOTE(review): this replaces (does not extend) the intermediate observers
                        # created just above — confirm that dropping them is intended.
                        node.ir.calib_attrs.intermediate_observers = {'mean': NodeObserver(do_mean_estimation=True)}
def _node_supports_channel_equalization(node: AwesomeNode) -> bool:
    """
    Tell whether SmoothQuant / channel equalization can transform ``node``.

    Only SiMaIR nodes whose operation is a ConvAddActivationOp or a LayerNormOp qualify.

    :param node: AwesomeNode to inspect.
    :return: True if the node's operator supports the transformation.
    """
    if not node_is_sima_ir(node):
        return False
    equalizable_ops = (afe_op.ConvAddActivationOp, afe_op.LayerNormOp)
    return isinstance(node.ir.operation, equalizable_ops)
class InsertPerChannelNodeObservers(BaseTransform):
    """
    Inserts NodeObservers that collect per-channel statistics into AwesomeNet's nodes.

    To save analysis time, a per-channel observer is created only for nodes that
    SmoothQuant and channel equalization can transform. Every other node's observer
    is cleared to None. Nested AwesomeNets are processed recursively.
    """

    def __init__(self):
        """This transform takes no configuration."""

    def __call__(self, net: AwesomeNet) -> None:
        for node in net.nodes.values():
            if node_is_awesomenet(node):
                self(node.ir)
                continue

            new_observer = None
            if _node_supports_channel_equalization(node):
                new_observer = NodeObserver(per_channel_min_max=True,
                                            asymmetry=node.ir.quant_config.asymmetry.get(),
                                            node_reporter=LogNodeReporter(node.name))
            node.ir.calib_attrs.observer = new_observer
class RemoveNodeObservers(BaseTransform):
    """
    Strip NodeObservers from every node of an AwesomeNet.

    Both ``calib_attrs.observer`` and ``calib_attrs.intermediate_observers`` are set to
    None, so the PyTorch infrastructure used during calibration is no longer referenced
    by the nodes. Nested AwesomeNets are processed recursively.
    """

    def __call__(self, net: AwesomeNet) -> None:
        for node in net.nodes.values():
            if node_is_awesomenet(node):
                self(node.ir)
                continue
            calib_attrs = node.ir.calib_attrs
            calib_attrs.observer = None
            calib_attrs.intermediate_observers = None
def _run_calibration_loop(name: str, length_hint: Optional[int], dataset: Iterable[Dict[str, np.ndarray]],
                          do_calibrate: Callable[[Dict[str, np.ndarray]], None]):
    """
    Loop over a calibration data set and call do_calibrate for each item.

    This function also prints status messages, draws a progress bar when the dataset
    size is known, and logs a progress message at most once every 5 seconds.

    :param name: The name of this calibration pass. This name will be used in status messages.
    :param length_hint: The hinted size of the dataset. It is used to show a progress bar.
        It will not limit the actual size of the dataset.
    :param dataset: Dataset to analyze.
    :param do_calibrate: Calibration function, called once per dataset item.
    :raises AssertionError: If ``dataset`` yields no items.
    """
    _msg = f"Running {name} ..."
    print(colored(_msg, "green"), end="\r")
    sima_logger.sima_log_info(_msg)

    # Time at which the last periodic progress message was logged (or the loop started).
    last_log_time = time.time()
    num_processed = 0
    for i, calibration_input in enumerate(dataset):
        if length_hint is not None:
            # Known input size; show progress
            print_progressbar(i + 1, length_hint, name + " Progress:",
                              "Complete. {}/{}".format(i + 1, length_hint),
                              length=30, print_end="")
        do_calibrate(calibration_input)
        num_processed = i + 1

        # If 5 seconds passed since last time, log a progress message
        current_time = time.time()
        if current_time - last_log_time > 5.0:
            sima_logger.sima_log_info(f"{name} progress: completed {num_processed} samples")
            last_log_time = current_time

    if num_processed == 0:
        # Raised explicitly instead of via ``assert`` so the check is not stripped under
        # ``python -O``. AssertionError is kept so existing callers see the same exception type.
        raise AssertionError("Unable to calibrate. Input data set is empty.")
    print(colored(_msg, "green") + DONE)
class Calibrate(BaseTransform):
    """
    Calibrates an AwesomeNet.

    Every dataset item is run through the network with ``calibrate_node`` as the
    per-node callback, and ``net.fp_input_range`` is refreshed from that item. The
    network's status is set to CALIBRATED once the whole dataset has been processed.

    :param length_hint: If not None, gives the number of items in the data set. This is used
        for progress reporting and does not affect the number of items used for calibration.
    :param dataset: The input data set for calibration: an iterable of input-name -> array dicts.
    """

    def __init__(self, length_hint: Optional[int], dataset: Iterable[Dict[str, np.ndarray]]):
        # DataGenerator used to be accepted here; it is no longer allowed.
        assert not isinstance(dataset, DataGenerator)
        self._length_hint = length_hint
        self._dataset = dataset

    def __call__(self, net: AwesomeNet) -> None:
        def calibrate_sample(sample: Dict[str, np.ndarray]):
            net.run(sample, node_callable=calibrate_node)
            # NOTE(review): _update_fp_input_range overwrites net.fp_input_range each call, so
            # after the loop it holds only the last sample's range — confirm this is intended.
            _update_fp_input_range(net, sample)

        _run_calibration_loop("Calibration", self._length_hint, self._dataset, calibrate_sample)
        update_awesomenet_status(net, Status.CALIBRATED)
class CalibratePerChannel(BaseTransform):
    """
    Run a calibration pass that gathers per-channel information.

    This pass is specialized for SmoothQuant and channel equalization. Each dataset item
    is run through the network with ``per_channel_calibrate_node`` as the per-node callback.
    Unlike ``Calibrate``, it neither refreshes the input range nor changes the network status.

    :param length_hint: If not None, the number of items in the data set. Used only for
        progress reporting.
    :param dataset: The input data set: an iterable of input-name -> array dicts.
    """

    def __init__(self, length_hint: Optional[int], dataset: Iterable[Dict[str, np.ndarray]]):
        # DataGenerator used to be accepted here; it is no longer allowed.
        assert not isinstance(dataset, DataGenerator)
        self._length_hint = length_hint
        self._dataset = dataset

    def __call__(self, net: AwesomeNet) -> None:
        def analyze_sample(sample: Dict[str, np.ndarray]):
            net.run(sample, node_callable=per_channel_calibrate_node)

        _run_calibration_loop("Equalizer Calibration", self._length_hint, self._dataset, analyze_sample)