#########################################################
# Copyright (C) 2022 SiMa Technologies, Inc.
#
# This material is SiMa proprietary and confidential.
#
# This material may not be copied or distributed without
# the express prior written permission of SiMa.
#
# All rights reserved.
#########################################################
# Code owner: Joey Chou
#########################################################
import numpy as np
from typing import Callable
from sima_utils.analysis import (
calculate_l1, calculate_l2, calculate_mae,
calculate_mse, calculate_psnr, calculate_mean,
calculate_std
)
from afe.common_utils import EnumHelper
from afe.ir.defines import TensorFormat
from afe.core.graph_analyzer.analyzed_results import (
AnalyzedResultType, AnalyzedResultDict
)
[docs]
class BaseGraphAnalyzerMode(str, EnumHelper):
    """
    Common base class for the GraphAnalyzer mode enumerations.

    The class declares no members and no methods of its own; everything is
    inherited from ``str`` and ``EnumHelper``. Concrete mode enumerations
    derive from it.

    NOTE(review): the intent is that an unknown mode produces a more
    informative error message through an overloaded ``_missing_`` method.
    No such override is defined in this class body, so it is presumably
    supplied by ``EnumHelper`` -- confirm.
    """
[docs]
class QuantizedGraphAnalyzerMode(BaseGraphAnalyzerMode):
    """
    Modes of the QuantizedGraphAnalyzer. Supported modes:

    * global_feed:
        Execute both the fp32 AwesomeNet and the quantized AwesomeNet using
        the same inputs. Compare the intermediates of both AwesomeNets and
        calculate the targeted metrics.
    * local_feed:
        Execute the fp32 AwesomeNet, then execute the given quantized
        AwesomeNet using the intermediates from the fp32 AwesomeNet. When
        executing a node in the quantized AwesomeNet, instead of using the
        output of its previous node(s), the intermediates from the fp32
        AwesomeNet are used. Each input to the node is quantized to int8
        before execution. Compare the intermediates of both AwesomeNets and
        calculate the targeted metrics.
    """
    # Both networks are driven by the same inputs.
    global_feed = "global_feed"
    # Each quantized node is fed from the fp32 network's intermediates.
    local_feed = "local_feed"
[docs]
class Metric(str, EnumHelper):
    """
    Enum class for the different metrics.

    Members
    -------
    :param l1: L1 error
    :param l2: L2 error
    :param mae: Mean absolute error
    :param mse: Mean square error
    :param psnr: Peak signal-to-noise ratio
    :param mean: Mean
    :param std: Standard deviation
    """
    # Each value equals its member name, matching the convention used by
    # QuantizedGraphAnalyzerMode. These members are the keys of the
    # module-level _METRIC_FUNC_DICT lookup table.
    l1 = "l1"
    l2 = "l2"
    mae = "mae"
    mse = "mse"
    psnr = "psnr"
    mean = "mean"
    std = "std"
#######
# Utils
#######
def _filter_base_on_threshold(analyzed_results: AnalyzedResultDict,
                              threshold: float,
                              filter_func: Callable[[float, float], bool]
                              ) -> AnalyzedResultDict:
    """
    Given an analyzed results dictionary, keep only the entries selected by
    the given filter function. The filter function takes two inputs (an
    analyzed value and the threshold) and returns a bool.

    Entries are handled according to the type of their value:

    * ``AnalyzedResultDict``: filtered recursively. The nested dictionary is
      always kept under its key, even when nothing inside it passes.
    * ``list``: the whole list is kept if at least one element passes. A
      ``float`` / ``np.float32`` element passes when
      ``filter_func(element, threshold)`` is true; a ``tuple`` / ``list``
      element passes when any of its items does.
    * anything else: dropped.

    Example
    -------
    .. code-block:: python

        lowpass_filter_func = lambda x, threshold: x < threshold
        analyzed_results = _filter_base_on_threshold(
            analyzed_results, threshold, lowpass_filter_func)

    Parameters
    ----------
    :param analyzed_results: AnalyzedResultDict. The dictionary that contains
        the analyzed results.
    :param threshold: float. Threshold used to filter out unwanted analyzed
        results.
    :param filter_func: Callable[[float, float], bool]. The filter function
        used with the threshold. The first input is the analyzed result and
        the second input is the threshold. If it returns True, the analyzed
        result is kept.

    Return
    ------
    :return: AnalyzedResultDict. A new AnalyzedResultDict that contains the
        targeted results after filtering. The input dictionary is not
        modified.

    Raises
    ------
    :raises AssertionError: If ``analyzed_results`` is not an
        AnalyzedResultDict.
    """
    # The message is one implicitly concatenated string. A comma between the
    # two pieces would turn it into a tuple and garble the assertion output.
    assert isinstance(analyzed_results, AnalyzedResultDict), (
        "The analyzed results filter expects an AnalyzedResultDict input dictionary. "
        f"Got {type(analyzed_results)}"
    )

    def _passes(analyzed_res) -> bool:
        # 1. A float number is a single analyzed result.
        if isinstance(analyzed_res, (float, np.float32)):
            return filter_func(analyzed_res, threshold)
        # 2. A tuple/list of floats holds multiple analyzed results; one
        #    passing item is enough.
        if isinstance(analyzed_res, (tuple, list)):
            return any(filter_func(_v, threshold) for _v in analyzed_res)
        # Any other element type never selects the entry.
        return False

    thresholded_results = AnalyzedResultDict()
    for k, v in analyzed_results.items():
        if isinstance(v, AnalyzedResultDict):
            # If the value is another AnalyzedResultDict, apply the filter recursively
            thresholded_results[k] = _filter_base_on_threshold(v, threshold, filter_func)
        elif isinstance(v, list) and any(_passes(analyzed_res) for analyzed_res in v):
            # Keep the complete list as soon as one element qualifies.
            thresholded_results[k] = v
    return thresholded_results
[docs]
def find_below_threshold(analyzed_results: AnalyzedResultDict,
                         threshold: float) -> AnalyzedResultDict:
    """
    Return only the analyzed results that fall below the targeted threshold.

    The work is delegated to ``_filter_base_on_threshold`` with a strict
    "less than" comparison, so a value equal to the threshold does not
    qualify.

    Parameters
    ----------
    :param analyzed_results: AnalyzedResultDict. The AnalyzedResultDict that
        will be filtered.
    :param threshold: float. The threshold value handed to the comparison.

    Return
    ------
    :return: AnalyzedResultDict. A filtered AnalyzedResultDict.
    """
    def _is_below(value, limit):
        return value < limit

    return _filter_base_on_threshold(analyzed_results, threshold, filter_func=_is_below)
[docs]
def find_above_threshold(analyzed_results: AnalyzedResultDict,
                         threshold: float) -> AnalyzedResultDict:
    """
    Return only the analyzed results that lie above the targeted threshold.

    The work is delegated to ``_filter_base_on_threshold`` with a strict
    "greater than" comparison, so a value equal to the threshold does not
    qualify.

    Parameters
    ----------
    :param analyzed_results: AnalyzedResultDict. The AnalyzedResultDict that
        will be filtered.
    :param threshold: float. The threshold value handed to the comparison.

    Return
    ------
    :return: AnalyzedResultDict. A filtered AnalyzedResultDict.
    """
    def _is_above(value, limit):
        return value > limit

    return _filter_base_on_threshold(analyzed_results, threshold, filter_func=_is_above)
# Lookup table from a Metric member to the callable that computes it.
# NOTE(review): the two *_tuple_list_datatype_helper wrappers are defined
# outside this view; presumably they adapt a raw metric function to
# tuple/list inputs -- confirm against their definitions.
_METRIC_FUNC_DICT = {
    # Metrics that take two AwesomeNets
    **{
        metric: double_inputs_tuple_list_datatype_helper(calculate)
        for metric, calculate in (
            (Metric.l1, calculate_l1),
            (Metric.l2, calculate_l2),
            (Metric.mae, calculate_mae),
            (Metric.mse, calculate_mse),
            (Metric.psnr, calculate_psnr),
        )
    },
    # Metrics that take a single AwesomeNet
    **{
        metric: single_input_tuple_list_datatype_helper(calculate)
        for metric, calculate in (
            (Metric.mean, calculate_mean),
            (Metric.std, calculate_std),
        )
    },
}
[docs]
def get_metric_func(metric: Metric) -> Callable:
    """
    Look up the calculation function registered for a metric.

    Parameters
    ----------
    :param metric: Metric. The metric whose calculation function is wanted.

    Return
    ------
    :return: Callable. The metric calculation function stored in
        ``_METRIC_FUNC_DICT`` for ``metric``.

    Raises
    ------
    :raises KeyError: If ``metric`` is not a key of ``_METRIC_FUNC_DICT``.
    """
    metric_func = _METRIC_FUNC_DICT[metric]
    return metric_func