#########################################################
# Copyright (C) 2022 SiMa Technologies, Inc.
#
# This material is SiMa proprietary and confidential.
#
# This material may not be copied or distributed without
# the express prior written permission of SiMa.
#
# All rights reserved.
#########################################################
# Code owner: Ljubomir Papuga
#########################################################
import itertools
import numpy as np
from typing import List, Callable, Any, Optional, IO, Dict, Iterable, TypeVar, Generic, Tuple, Iterator
from sima_utils.common import print_progressbar
from afe.ir.defines import NodeName
from afe.ir.utils import transpose_tensor_according_to_layout_strings
# Type of ground truth data for checking the output of a model
_GroundTruth = TypeVar('_GroundTruth')
# Internal data type when ground truth is transformed in multiple steps
_GroundTruth2 = TypeVar('_GroundTruth2')
_A = TypeVar('_A')
_B = TypeVar('_B')
def checked_zip(x: Iterable[_A], y: Iterable[_B]) -> Iterator[Tuple[_A, _B]]:
"""
Zip together two iterables that must have the same length. The returned
iterator behaves like zip, except that it raises an exception if one iterator
is longer than the other.
:param x: First iterable
:param y: Second iterable
:return: Iterable of pairs of values taken from x and y
"""
    x_iter = iter(x)
    y_iter = iter(y)
    while True:
        try:
            x_next = next(x_iter)
        except StopIteration:
            # End of x found. y must end as well.
            try:
                next(y_iter)
            except StopIteration:
                return
            raise ValueError("Zipped sequences do not have the same length")
        # End of x not found. y must not end now.
        try:
            y_next = next(y_iter)
        except StopIteration:
            raise ValueError("Zipped sequences do not have the same length")
        yield (x_next, y_next)
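
# A minimal usage sketch for checked_zip (illustrative values only):
#
#     >>> list(checked_zip([1, 2], ["a", "b"]))
#     [(1, 'a'), (2, 'b')]
#     >>> list(checked_zip([1, 2], ["a"]))
#     Traceback (most recent call last):
#     ...
#     ValueError: Zipped sequences do not have the same length
#
# On Python 3.10+, zip(x, y, strict=True) provides equivalent behavior.
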
def _is_like_dataset_performance_analyzer(obj: Any) -> bool:
"""
Check whether the given object has the expected API
attributes of a dataset performance analyzer.
"""
return (hasattr(obj.__class__, "reset") and callable(obj.reset)
and hasattr(obj.__class__, "compare") and callable(obj.compare)
and hasattr(obj.__class__, "analyze_results") and callable(obj.analyze_results)
and hasattr(obj.__class__, "performance"))
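
# Sketch of the duck-typed analyzer API that this check expects. The class below
# is a hypothetical example, not an analyzer shipped with this module:
#
#     class TopOneAccuracyAnalyzer:
#         def reset(self) -> None: ...
#         def compare(self, net_out: List[np.ndarray], gt_data: Any) -> Optional[str]: ...
#         def analyze_results(self) -> str: ...
#         @property
#         def performance(self) -> float: ...
#
#     assert _is_like_dataset_performance_analyzer(TopOneAccuracyAnalyzer())
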
class ComposeDatasetCompareFunction(Generic[_GroundTruth]):
"""
ComposeDatasetCompareFunction(f, a) returns a dataset performance analyzer that
behaves like a, except that its compare function is equivalent to
def compare(out, gt):
out2, gt2 = f(out, gt)
return a.compare(out2, gt2)
See DataSetPerformanceAnalyzerProxy for method documentation.
"""
_dataset_performance_analyzer: Any
_transform_outputs: Callable[[List[np.ndarray], _GroundTruth], Tuple[List[np.ndarray], _GroundTruth2]]
def __init__(self, dataset_performance_analyzer: Any,
transform_outputs: Callable[[List[np.ndarray], _GroundTruth], Tuple[List[np.ndarray], _GroundTruth2]]):
assert _is_like_dataset_performance_analyzer(dataset_performance_analyzer)
self._dataset_performance_analyzer = dataset_performance_analyzer
self._transform_outputs = transform_outputs

    def reset(self) -> None:
        # Delegate to the performance analyzer
        self._dataset_performance_analyzer.reset()

    def compare(self, net_out: List[np.ndarray], gt_data: _GroundTruth) -> Optional[str]:
        # Transform values, then delegate to the performance analyzer
        net_out2, gt_data2 = self._transform_outputs(net_out, gt_data)
        return self._dataset_performance_analyzer.compare(net_out2, gt_data2)

    def analyze_results(self) -> str:
        # Delegate to the performance analyzer
        return self._dataset_performance_analyzer.analyze_results()

    @property
    def performance(self) -> float:
        # Delegate to the performance analyzer
        return self._dataset_performance_analyzer.performance
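
# A hedged usage sketch for ComposeDatasetCompareFunction. Here `analyzer` stands
# for any object satisfying _is_like_dataset_performance_analyzer, and
# `take_argmax` is a hypothetical post-processing transform, not part of this
# module:
#
#     def take_argmax(out: List[np.ndarray], gt: int) -> Tuple[List[np.ndarray], int]:
#         # Reduce class scores to a predicted label before comparison
#         return [np.argmax(out[0], axis=-1)], gt
#
#     composed = ComposeDatasetCompareFunction(analyzer, take_argmax)
#     composed.compare(net_out, gt)  # equals analyzer.compare(*take_argmax(net_out, gt))
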
class GraphEvaluatorLogger(object):
"""
Used for printing the progress and results of graph evaluation.
Attribute
----------
:attribute _verbose: bool. Whether to print out the progress and results.
If set to False, logging will be disabled.
:attribute _log_file: Optional[IO]. The IO object used to keep the graph evaluation logs, if any.
"""
_verbose: bool
_log_file: Optional[IO]
def __init__(self, verbose: bool, log_filename: Optional[str]):
self._verbose = verbose
self._log_file = None
if log_filename is not None:
self._log_file = open(log_filename, "w")
assert self._log_file is not None

    def print_progressbar(self, current_step: int, total_steps: int, analysis_str: str):
        """
        Prints the progress bar if logging is enabled.

        :param current_step: int. The current step in the graph evaluation process.
        :param total_steps: int. The total number of steps in the evaluation process.
        :param analysis_str: str. The output of the current step in the evaluation.
        """
        if self._verbose:
            print_progressbar(current_step, total_steps, "Evaluation Progress:",
                              "Complete. " + analysis_str, length=30, print_end="")

    def print_analysis_str(self, analysis_str: str):
        """
        Prints out the analysis string.

        :param analysis_str: str. The output of the current step in the evaluation.
        """
        if self._verbose and analysis_str != "":
            print(analysis_str)
            if self._log_file is not None:
                self._log_file.write(analysis_str + "\n")

    def print_analysis_summary(self, get_analysis_summary: Callable[[], str]):
        """
        Prints out the summary of the graph evaluation process.

        :param get_analysis_summary: Callable. The function used to get the results of the evaluation.
        """
        if self._verbose:
            analysis_summary_str = get_analysis_summary()
            print(analysis_summary_str)
            if self._log_file is not None:
                self._log_file.write(analysis_summary_str + "\n")

    def print_error_message(self):
        """
        Prints out a warning message in case the evaluation didn't yield any results.
        """
        msg = "Cannot evaluate network performance, setting performance to zero."
        print(msg)
        if self._log_file is not None:
            self._log_file.write(msg + "\n")
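
# A minimal usage sketch for GraphEvaluatorLogger ("eval.log" and the analysis
# strings are arbitrary example values, not values this module produces):
#
#     logger = GraphEvaluatorLogger(verbose=True, log_filename="eval.log")
#     logger.print_progressbar(1, 10, "top-1 accuracy: 0.72")
#     logger.print_analysis_str("top-1 accuracy: 0.72")
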
class GraphEvaluator(Generic[_GroundTruth]):
    """
    Wrapper class encapsulating objects used in graph evaluation.

    Attributes
    ----------
    :attribute input_generator: DataGenerator. Used to generate input data used
        in graph evaluation.
    :attribute ground_truth_data_generator: DataGenerator. Used to generate ground truth
        outputs used in graph evaluation.
    :attribute dataset_performance_analyzer: PerformanceAnalyzerProxy. Used to perform
        graph evaluation.
    :attribute sample_count_hint: Optional[int]. Number of samples in the input, used for
        progress reporting. Does not affect the number of samples actually processed
        from the input. If None, the number of samples in the input is unknown and
        progress is not shown.
    :attribute transpose_output: bool. Whether to transpose outputs from NHWC to NCHW layout.
    """
    ground_truth_data_generator: Iterable[_GroundTruth]
    sample_count_hint: Optional[int]

    def __init__(self, input_generator: Iterable[Dict[NodeName, np.ndarray]],
                 ground_truth_data_generator: Iterable[_GroundTruth],
                 performance_analyzer: Any, sample_count_hint: Optional[int],
                 transpose_output: bool = False):
        self.input_generator = input_generator
        self.ground_truth_data_generator = ground_truth_data_generator
        self.dataset_performance_analyzer = PerformanceAnalyzerProxy(performance_analyzer)
        self.sample_count_hint = sample_count_hint
        self.transpose_output = transpose_output

    def evaluate(self, run_func: Callable[[Dict[NodeName, np.ndarray]], List[np.ndarray]],
                 verbose: bool = False, analysis_log_filename: Optional[str] = None) -> float:
        """
        Perform evaluation of the network performance. If an IndexError is raised while
        performing the evaluation, the performance is set to zero.

        :param run_func: Callable[[Dict[NodeName, np.ndarray]], List[np.ndarray]]. Function
            which takes an input dataset and generates inference results.
        :param verbose: bool. Default is False. If set to True, print out the evaluation results.
        :param analysis_log_filename: Optional[str]. Default is None. If given, represents the
            file to which the evaluation results shall be logged.
        :return: float. A number from 0 to 1 indicating the network's performance.
        """
        logger = GraphEvaluatorLogger(verbose, analysis_log_filename)
        try:
            self.dataset_performance_analyzer.reset()
            input_source = enumerate(checked_zip(self.input_generator, self.ground_truth_data_generator))
            for i, (network_input, ground_truth_outputs) in input_source:
                # Execute the model
                net_output = run_func(network_input)
                if self.transpose_output:
                    for idx in range(len(net_output)):
                        # Transpose only 4D tensors
                        if net_output[idx].ndim == 4:
                            net_output[idx] = transpose_tensor_according_to_layout_strings(
                                net_output[idx], "NHWC", "NCHW")

                # Post-process and analyze the results
                analysis_str = self.dataset_performance_analyzer.compare(net_output, ground_truth_outputs)
                analysis_str = analysis_str or ""
                if self.sample_count_hint is not None:
                    logger.print_progressbar(i + 1, self.sample_count_hint, analysis_str)
                logger.print_analysis_str(analysis_str)
            logger.print_analysis_summary(self.dataset_performance_analyzer.analyze_results)
            return self.dataset_performance_analyzer.performance
        except IndexError:
            logger.print_error_message()
            return 0.0
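
# End-to-end sketch of GraphEvaluator usage. The inputs, labels, analyzer, and
# run_model below are hypothetical stand-ins for a real dataset and model runner:
#
#     inputs = [{"input": np.zeros((1, 224, 224, 3), dtype=np.float32)}]
#     labels = [0]
#     evaluator = GraphEvaluator(inputs, labels, analyzer,
#                                sample_count_hint=len(inputs),
#                                transpose_output=True)
#
#     def run_model(network_input: Dict[NodeName, np.ndarray]) -> List[np.ndarray]:
#         ...  # run inference, return output tensors in NHWC layout
#
#     score = evaluator.evaluate(run_model, verbose=True)  # float in [0, 1]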