Source code for afe.ir.debug

#########################################################
# Copyright (C) 2021 SiMa Technologies, Inc.
#
# This material is SiMa proprietary and confidential.
#
# This material may not be copied or distributed without
# the express prior written permission of SiMa.
#
# All rights reserved.
#########################################################
# Code owner: Joey Chou
#########################################################
import numpy as np
from termcolor import colored
from typing import List, Tuple, Union, Optional, Dict, Any, Callable, Set, Iterable

from afe.common_utils import search_matched_node_names, get_index_from_node_name
from afe.core.graph_analyzer.analyzed_results import AnalyzedResultDict
from afe.core.graph_analyzer.graph_analyzer import QuantizedGraphAnalyzer
from afe.core.graph_analyzer.utils import QuantizedGraphAnalyzerMode, Metric
from afe.ir.net import AwesomeNet
from afe.ir.node import node_is_awesomenet
from afe.ir.operations import ConstantOp
from afe.ir.sima_ir import SiMaIR


def fix_constant_input_nodes_to_float(net: AwesomeNet, nodes_to_be_fixed: Set[str]):
    """
    Extend ``nodes_to_be_fixed`` in place with the constant inputs of the nodes it contains.

    A node counts as a constant input when its IR is a ``SiMaIR`` whose operation is a
    ``ConstantOp``. Only the direct inputs of the nodes present in ``nodes_to_be_fixed`` at call
    time are inspected. The set is updated once, after the whole scan has finished.

    :param net: AwesomeNet. Network whose nodes are being fixed to floating point.
    :param nodes_to_be_fixed: Set[str]. Names of the nodes to be fixed to floating point. It is
        mutated in place: the names of their constant input nodes are added to it.
    """
    def _is_constant_node(candidate_name: str) -> bool:
        # True when the node's IR is a SiMaIR wrapping a ConstantOp.
        candidate = net.nodes[candidate_name]
        return isinstance(candidate.ir, SiMaIR) and isinstance(candidate.ir.operation, ConstantOp)

    # Collect first and mutate afterwards, so the set is not modified while it is iterated.
    constant_inputs: Set[str] = {
        input_name
        for fixed_name in nodes_to_be_fixed
        for input_name in net.nodes[fixed_name].input_node_names
        if _is_constant_node(input_name)
    }
    nodes_to_be_fixed.update(constant_inputs)
def fix_nodes_to_float(net: AwesomeNet,
                       patterns: List[Union[str, int, Tuple[int, int]]],
                       excluded_patterns: Optional[List[Union[str, int, Tuple[int, int]]]] = None,
                       verbose: bool = True,
                       fix_constant_input_nodes: bool = False) -> None:
    """
    Fix the nodes of an AwesomeNet that match the given patterns to floating point.

    The patterns in List[Union[str, int, Tuple[int, int]]] are turned into a node name pattern
    set as follows:

        1. str: Added to the set as-is. Wildcards are supported.
        2. int: Added to the set as ``*_{number}``.
        3. Tuple[int, int]: Unrolled to the range of integers from the first integer to the
           second integer (both included). Each integer is converted to a ``*_{number}``
           string and added to the set.

    The set is used to pattern-match every node name in the given AwesomeNet. Each matched
    node name is added to ``net.float_node_list``. Nested AwesomeNets are processed
    recursively with the same arguments, so each sub-net updates its own float_node_list.
    The node is still calibrated, so that downstream nodes can obtain its zero points and
    scales, but it is not quantized.

    Example
    -------
    The example adds the nodes with indices [2, 3, 10, 11, 12, 13, 14] and all the nodes
    containing "conv" in their name to net.float_node_list. Nodes with "conv2d_transpose" in
    their name are excluded and are not fixed to float. With verbose set to True, the names of
    the nodes added to float_node_list are printed.

    .. code-block:: python

        patterns = [2, (11, 14), "*conv*", "3", "10"]
        excluded_patterns = ["*conv2d_transpose*"]
        fix_nodes_to_float(net, patterns, excluded_patterns, verbose=True)

    Parameters
    ----------
    :param net: AwesomeNet. The AwesomeNet containing the nodes the user wants to fix to float.
    :param patterns: List[Union[str, int, Tuple[int, int]]]. List of patterns.
    :param excluded_patterns: Optional[List[Union[str, int, Tuple[int, int]]]]. Default is None.
        List of patterns excluded from the pattern matching. Node names that match these
        patterns are removed from the matched names.
    :param verbose: bool. Default is True. Set to False to hide the printed messages. The value
        is also forwarded to nested AwesomeNets.
    :param fix_constant_input_nodes: bool. Default is False. Set to True to also fix all
        constant input nodes of the matched nodes. This avoids errors where fixing a node to
        floating point would leave its constant input without quantization data. The value is
        also forwarded to nested AwesomeNets.

    Return
    ------
    :return: None
    """
    nodes_to_be_fixed = search_matched_node_names(net.nodes.keys(), patterns, excluded_patterns)

    # Recurse into nested AwesomeNets so every sub-net updates its own float_node_list.
    # Forward the caller's options; previously sub-nets silently used the defaults
    # (always printing and never fixing constant inputs).
    for node in net.nodes.values():
        if node_is_awesomenet(node):
            fix_nodes_to_float(node.ir, patterns, excluded_patterns,
                               verbose=verbose,
                               fix_constant_input_nodes=fix_constant_input_nodes)

    if fix_constant_input_nodes:
        # Mutates nodes_to_be_fixed in place, adding the constant input node names.
        fix_constant_input_nodes_to_float(net, nodes_to_be_fixed)

    if verbose:
        for name in nodes_to_be_fixed:
            print(f"Fixed node {name} to float")
    net.extend_float_node_list(list(nodes_to_be_fixed))
def get_critical_node_local_feed(results: Dict[Metric, AnalyzedResultDict], float_nodes: List[str],
                                 metric: Metric, verbose: bool = False) -> str:
    """
    Determine the critical node, i.e. the one with the worst quantization result for the given
    metric, for an analysis done in QuantizedGraphAnalyzerMode.local_feed mode.

    For Metric.psnr a lower value is worse; for every other metric a higher value is worse.
    Only values strictly worse than the starting threshold (100.0 for PSNR, 0.0 otherwise) can
    be selected. Nodes listed in ``float_nodes`` and nodes whose index is 0 are never selected.
    Per-node results may be a single number or nested tuples/lists of numbers; every contained
    number is considered.

    Parameters
    ----------
    :param results: Dict[Metric, AnalyzedResultDict]. Results of the graph quantization analysis.
    :param float_nodes: List[str]. If given, the list of nodes already set to fp32; these are
        skipped.
    :param metric: Metric. Metric used for determining the critical node.
    :param verbose: bool. Default is False. If set to True, the critical node name is printed.

    Return
    ------
    :return: str. Name of the critical node.
    :raises ValueError: If ``metric`` is missing from ``results``, or if no eligible node is
        worse than the starting threshold.
    """
    # TODO: Analyze what metric is to be used for setting the critical node.
    # Explicit checks instead of `assert`, so that validation survives `python -O`.
    if metric not in results:
        raise ValueError(f"Metric {metric} not found in quantization analysis results")
    metric_results = results[metric]

    lower_is_worse = metric == Metric.psnr
    critical_value: float = 100.0 if lower_is_worse else 0.0
    critical_node: Optional[str] = None
    # Build the exclusion set once for O(1) membership tests; None means nothing is excluded.
    excluded_nodes = None if float_nodes is None else set(float_nodes)

    def is_worse(value: float, reference: float) -> bool:
        return value < reference if lower_is_worse else value > reference

    def set_critical(r: Union[float, Tuple[Any, ...], List[Any]], n: str, cv: float, cn: Optional[str]) -> \
            Tuple[float, Optional[str]]:
        # Accept numpy floating scalars too (np.float32 is not a subclass of float); previously
        # they fell through to the iteration branch and raised TypeError.
        if isinstance(r, (float, np.floating)):
            value = float(r)
            if is_worse(value, cv) and \
                    (excluded_nodes is None or n not in excluded_nodes) and \
                    get_index_from_node_name(n) != 0:
                cv = value
                cn = n
        else:
            for _r in r:
                cv, cn = set_critical(_r, n, cv, cn)
        return cv, cn

    for node_name, res in metric_results.items():
        critical_value, critical_node = set_critical(res, node_name, critical_value, critical_node)

    if critical_node is None:
        raise ValueError(f"No critical node found for metric {metric}: no eligible node has a "
                         f"value worse than {critical_value}")
    if verbose:
        print(colored(f"Critical node is {critical_node}, {metric} = {critical_value}", "red"))
    return critical_node
def get_critical_node_global_feed(results: Dict[Metric, AnalyzedResultDict], float_nodes: List[str],
                                  metric: Metric, verbose: bool = False) -> str:
    """
    Determine the critical node, i.e. the one with the worst quantization result for the given
    metric, for an analysis done in QuantizedGraphAnalyzerMode.global_feed mode.

    Not implemented yet: this function always raises NotImplementedError.

    Parameters
    ----------
    :param results: Dict[Metric, AnalyzedResultDict]. Results of the graph quantization analysis.
    :param float_nodes: List[str]. If given, the list of nodes already set to fp32.
    :param metric: Metric. Metric used for determining the critical node.
    :param verbose: bool. Default is False. If set to True, the critical node name is printed.

    Return
    ------
    :return: str. Name of the critical node.
    :raises NotImplementedError: Always; global_feed analysis is not supported yet.
    """
    # TODO: Add algorithm determining the critical node when using global_feed.
    raise NotImplementedError("Improving quantized net using "
                              "QuantizedGraphAnalyzerMode.global_feed is not supported")
def get_critical_node(results: Dict[Metric, AnalyzedResultDict], float_nodes: List[str],
                      mode: Union[str, QuantizedGraphAnalyzerMode], metric: Metric,
                      verbose: bool = False) -> str:
    """
    Determine the critical node, i.e. the one with the worst quantization result for the given
    metric, by dispatching on the analysis mode.

    Parameters
    ----------
    :param results: Dict[Metric, AnalyzedResultDict]. Results of the graph quantization analysis.
    :param float_nodes: List[str]. If given, the list of nodes already set to fp32.
    :param mode: Union[str, QuantizedGraphAnalyzerMode]. Analysis mode the results were produced
        in. A string is converted with ``QuantizedGraphAnalyzerMode(mode)``.
    :param metric: Metric. Metric used for determining the critical node.
    :param verbose: bool. Default is False. If set to True, the critical node name is printed.

    Return
    ------
    :return: str. Name of the critical node.
    :raises TypeError: If ``mode`` is neither a str nor a QuantizedGraphAnalyzerMode.
    :raises ValueError: If ``mode`` has no handler.
    """
    if isinstance(mode, str):
        mode = QuantizedGraphAnalyzerMode(mode)
    if not isinstance(mode, QuantizedGraphAnalyzerMode):
        raise TypeError(f"Parameter mode is of unsupported type: {type(mode)}")

    # Mode -> handler dispatch table.
    handlers = {
        QuantizedGraphAnalyzerMode.local_feed: get_critical_node_local_feed,
        QuantizedGraphAnalyzerMode.global_feed: get_critical_node_global_feed,
    }
    handler = handlers.get(mode)
    if handler is None:
        raise ValueError(f"Unsupported QuantizedGraphAnalyzerMode: {mode}")
    return handler(results, float_nodes, metric, verbose)
def improve_quantized_net_performance(
        calibrated_net: AwesomeNet,
        quantized_net: AwesomeNet,
        reference_net: AwesomeNet,
        input_data_set: Iterable[Dict[str, np.ndarray]],
        mode: QuantizedGraphAnalyzerMode = QuantizedGraphAnalyzerMode.local_feed,
        metric: Metric = Metric.mse
) -> str:
    """
    Improve the precision of a calibrated AwesomeNet by one step.

    The step runs the graph quantization analysis, picks the critical node and fixes that node
    (together with its constant inputs) to floating point in ``calibrated_net``.

    Parameters
    ----------
    :param calibrated_net: AwesomeNet. Calibrated AwesomeNet whose precision needs to be
        improved. Its float node list is extended in place.
    :param quantized_net: AwesomeNet. Quantized AwesomeNet that is analyzed. Its
        float_node_list gives the nodes that are already fp32.
    :param reference_net: AwesomeNet. Reference FP32 AwesomeNet.
    :param input_data_set: Iterable of input data used for the quantization analysis. It is
        recommended to use a smaller data set than calibration would use.
    :param mode: QuantizedGraphAnalyzerMode. Default is QuantizedGraphAnalyzerMode.local_feed.
        Mode in which the graph analyzer is run.
    :param metric: Metric. Default is Metric.mse. Metric used to determine the critical node.

    Return
    ------
    :return: str. Name of the node that was fixed to higher precision.
    """
    # Compare the quantized net against the FP32 reference, in the requested mode, using
    # only the requested metric.
    graph_analyzer = QuantizedGraphAnalyzer(mode)
    graph_analyzer.analyze(reference_net, quantized_net, input_data_set, [metric])

    # Nodes already in the quantized net's float list are passed as float_nodes so they are
    # not chosen again.
    critical_node_name = get_critical_node(graph_analyzer.analyzed_results,
                                           quantized_net.float_node_list,
                                           mode,
                                           metric)

    # The node name itself is used as the match pattern; its constant inputs are fixed too.
    fix_nodes_to_float(calibrated_net, [critical_node_name], fix_constant_input_nodes=True)
    return critical_node_name