#########################################################
# Copyright (C) 2021 SiMa Technologies, Inc.
#
# This material is SiMa proprietary and confidential.
#
# This material may not be copied or distributed without
# the express prior written permission of SiMa.
#
# All rights reserved.
#########################################################
# Code owner: Joey Chou
#########################################################
import numpy as np
from termcolor import colored
from typing import List, Tuple, Union, Optional, Dict, Any, Callable, Set, Iterable
from afe.common_utils import search_matched_node_names, get_index_from_node_name
from afe.core.graph_analyzer.analyzed_results import AnalyzedResultDict
from afe.core.graph_analyzer.graph_analyzer import QuantizedGraphAnalyzer
from afe.core.graph_analyzer.utils import QuantizedGraphAnalyzerMode, Metric
from afe.ir.net import AwesomeNet
from afe.ir.node import node_is_awesomenet
from afe.ir.operations import ConstantOp
from afe.ir.sima_ir import SiMaIR
[docs]
def fix_nodes_to_float(net: AwesomeNet,
                       patterns: List[Union[str, int, Tuple[int, int]]],
                       excluded_patterns: Optional[List[Union[str, int, Tuple[int, int]]]] = None,
                       verbose: bool = True,
                       fix_constant_input_nodes: bool = False) -> None:
    """
    Mark the nodes of ``net`` whose names match ``patterns`` as fixed to floating point.

    Given patterns in List[Union[str, int, Tuple[int, int]]], generate a node name pattern set
    using the different types of pattern as below:
        1. str: Add to the set. Supports wildcards.
        2. int: Add *_{number} to the set.
        3. Tuple[int, int]: Unroll the tuple of two integers to a range of integers
           where the lower bound is the first integer of the tuple and the upper
           bound is the second integer of the tuple. Each integer is converted to a
           *_{number} string and added to the set.

    Use the set to pattern match each node name in the given AwesomeNet and add the matched node
    names to net.float_node_list. Nodes that are themselves AwesomeNets are processed recursively
    with the same patterns and options, so each sub awesome-net has its own float_node_list
    updated. The nodes will still be calibrated so the downstream nodes can have their zero
    points and scales, but they won't get quantized.

    Example
    -------
    The example adds the nodes with indices equal to [2, 3, 10, 11, 12, 13, 14] and all the
    nodes containing "conv" in the node name to net.float_node_list. The nodes with
    "conv2d_transpose" in the node name are excluded and will not be fixed to float.
    Set verbose to True to print out the node names that are added to float_node_list.

    .. code-block:: python

        patterns = [2, (11, 14), "*conv*", "3", "10"]
        excluded_patterns = ["*conv2d_transpose*"]
        fix_nodes_to_float(net, patterns, excluded_patterns, verbose=True)

    Parameters
    ----------
    :param net: AwesomeNet. The AwesomeNet containing the nodes the user wants to fix to float.
    :param patterns: List[Union[str, int, Tuple[int, int]]]. List of patterns.
    :param excluded_patterns: Optional[List[Union[str, int, Tuple[int, int]]]]. Default is None.
        List of patterns that will be excluded from the pattern matching. Node names that match
        these patterns will be excluded from the matched nodes.
    :param verbose: bool. Default is True. Set to False to hide the printed messages, including
        those for nested sub awesome-nets.
    :param fix_constant_input_nodes: bool. Default is False. Set to True to also fix all constant
        input nodes of the nodes matching the patterns (in this net and in nested sub-nets). This
        avoids errors in cases where fixing one node to floating point would leave its constant
        input without quantization data.

    Return
    ------
    :return: None
    """
    nodes_to_be_fixed = search_matched_node_names(net.nodes.keys(), patterns, excluded_patterns)

    # Recurse into nested AwesomeNets. Forward every caller option: previously only the patterns
    # were passed, so sub-nets silently reverted to verbose=True and never had their constant
    # input nodes fixed.
    for node in net.nodes.values():
        if node_is_awesomenet(node):
            fix_nodes_to_float(node.ir, patterns, excluded_patterns,
                               verbose=verbose,
                               fix_constant_input_nodes=fix_constant_input_nodes)

    if fix_constant_input_nodes:
        fix_constant_input_nodes_to_float(net, nodes_to_be_fixed)

    if verbose:
        for name in nodes_to_be_fixed:
            print(f"Fixed node {name} to float")
    net.extend_float_node_list(list(nodes_to_be_fixed))
[docs]
def get_critical_node_local_feed(results: Dict[Metric, AnalyzedResultDict],
                                 float_nodes: Optional[List[str]],
                                 metric: Metric,
                                 verbose: bool = False) -> str:
    """
    Determine the critical node with the worst quantization result for the given metric, for
    analysis done in QuantizedGraphAnalyzerMode.local_feed mode.

    For Metric.psnr a lower value is worse and only values below 100.0 are considered; for every
    other metric a higher value is worse and only values above 0.0 are considered. A node's
    result may be a single real number or an arbitrarily nested tuple/list of real numbers; every
    scalar is compared. Nodes listed in ``float_nodes`` and nodes for which
    ``get_index_from_node_name`` returns 0 are never selected.

    Parameters
    ----------
    :param results: Dict[Metric, AnalyzedResultDict]. Results of the graph quantization analysis.
    :param float_nodes: Optional[List[str]]. If given, the list of nodes already set to fp32.
    :param metric: Metric. Defines the metric used for determining the critical node.
    :param verbose: bool. Default is False. If set to True the function prints the critical node
        name and its metric value.

    Return
    ------
    :return: str. Name of the critical node.
    :raises AssertionError: if ``metric`` is not in ``results`` or no node qualifies as critical.
    """
    import numbers

    # TODO: Analyze what metric is to be used for setting the critical node.
    assert metric in results, f"Metric {metric} not found in analysis results"
    metric_results = results[metric]

    is_psnr = metric == Metric.psnr
    # The starting value doubles as a threshold: only results strictly worse than it can win.
    critical_value: float = 100.0 if is_psnr else 0.0
    critical_node: Optional[str] = None
    # Generic comparisons (instead of float.__lt__/__gt__) so numpy scalars and ints work too.
    compare_fn: Callable[[float, float], bool] = \
        (lambda a, b: a < b) if is_psnr else (lambda a, b: a > b)
    # Convert once: membership is tested for every scalar result.
    already_float: Set[str] = set(float_nodes) if float_nodes is not None else set()

    def set_critical(r: Union[float, Tuple[Any, ...], List[Any]], n: str, cv: float,
                     cn: Optional[str]) -> Tuple[float, Optional[str]]:
        # numbers.Real covers float, int and numpy scalars such as numpy.float32, which is not a
        # float subclass and previously fell into the iteration branch (raising TypeError).
        if isinstance(r, numbers.Real):
            # Same evaluation order as before, so get_index_from_node_name is only
            # called for candidates that pass the cheaper checks.
            if compare_fn(r, cv) and \
                    n not in already_float and \
                    get_index_from_node_name(n) != 0:
                return r, n
            return cv, cn
        # Otherwise the result is a (possibly nested) sequence of scalars.
        for _r in r:
            cv, cn = set_critical(_r, n, cv, cn)
        return cv, cn

    for node_name, res in metric_results.items():
        critical_value, critical_node = set_critical(res, node_name, critical_value, critical_node)

    assert critical_node is not None, \
        f"No critical node found for metric {metric}"
    if verbose:
        print(colored(f"Critical node is {critical_node}, {metric} = {critical_value}", "red"))
    return critical_node
[docs]
def get_critical_node_global_feed(results: Dict[Metric, AnalyzedResultDict],
                                  float_nodes: Optional[List[str]],
                                  metric: Metric,
                                  verbose: bool = False) -> str:
    """
    Determine the critical node with the worst quantization result for the given metric, for
    analysis done in QuantizedGraphAnalyzerMode.global_feed mode.

    Not implemented yet: calling this function always raises NotImplementedError. The signature
    mirrors get_critical_node_local_feed so both can be dispatched identically.

    Parameters
    ----------
    :param results: Dict[Metric, AnalyzedResultDict]. Results of the graph quantization analysis.
    :param float_nodes: Optional[List[str]]. If given, the list of nodes already set to fp32.
    :param metric: Metric. Defines the metric used for determining the critical node.
    :param verbose: bool. Default is False. If set to True the function would print out the
        critical node name.

    Return
    ------
    :return: str. Name of the critical node.
    :raises NotImplementedError: always, until a global_feed algorithm is added.
    """
    # TODO: Add algorithm determining the critical node when using global_feed.
    raise NotImplementedError("Improving quantized net using "
                              "QuantizedGraphAnalyzerMode.global_feed is not supported")
[docs]
def get_critical_node(results: Dict[Metric, AnalyzedResultDict],
                      float_nodes: List[str],
                      mode: Union[str, QuantizedGraphAnalyzerMode],
                      metric: Metric,
                      verbose: bool = False) -> str:
    """
    Find the node with the worst quantization result for ``metric``.

    The work is delegated to the mode-specific implementation selected by ``mode``.

    Parameters
    ----------
    :param results: Dict[Metric, AnalyzedResultDict]. Results of the graph quantization analysis.
    :param float_nodes: List[str]. If given, the list of nodes already set to fp32.
    :param mode: Union[str, QuantizedGraphAnalyzerMode]. Analysis mode the results were produced
        in; a string is converted with ``QuantizedGraphAnalyzerMode(mode)``.
    :param metric: Metric. Defines the metric used for determining the critical node.
    :param verbose: bool. Default is False. If set to True the critical node name is printed.

    Return
    ------
    :return: str. Name of the critical node.
    :raises TypeError: if ``mode`` is neither a str nor a QuantizedGraphAnalyzerMode.
    :raises ValueError: if ``mode`` has no critical-node implementation.
    """
    # Normalize the user-supplied mode into the enum before dispatching.
    if isinstance(mode, str):
        analyzer_mode = QuantizedGraphAnalyzerMode(mode)
    elif isinstance(mode, QuantizedGraphAnalyzerMode):
        analyzer_mode = mode
    else:
        raise TypeError(f"Parameter mode is of unsupported type: {type(mode)}")

    handlers = {
        QuantizedGraphAnalyzerMode.local_feed: get_critical_node_local_feed,
        QuantizedGraphAnalyzerMode.global_feed: get_critical_node_global_feed,
    }
    handler = handlers.get(analyzer_mode)
    if handler is None:
        raise ValueError(f"Unsupported QuantizedGraphAnalyzerMode: {analyzer_mode}")
    return handler(results, float_nodes, metric, verbose)