# Source code for afe.apis.statistic

#########################################################
# Copyright (C) 2022 SiMa Technologies, Inc.
#
# This material is SiMa proprietary and confidential.
#
# This material may not be copied or distributed without
# the express prior written permission of SiMa.
#
# All rights reserved.
#########################################################
# Code owner: Christopher Rodrigues
#########################################################
"""
Analysis of statistics on tensors.
"""
import dataclasses
from typing import Callable, Generic, TypeVar, Type, Any, Tuple, List

import numpy as np

from afe.apis.compilation_job_base import Tensor
from afe.driver.statistic import Statistic
# Exporting
from afe.core.evaluate_networks import checked_zip


# Distance metric between two tensors: a callable that receives two tensors
# and returns their distance as a float.  Consumed by threshold_test_counter,
# mean and mean_text below.
Metric = Callable[[Tensor, Tensor], float]
def equality(x: Tensor, y: Tensor) -> float:
    """
    Equality as a distance metric.

    Return 0 if the tensors have the same shape and the same elements,
    1 otherwise.

    ``np.array_equal`` is used instead of ``(x == y).all()`` so that numpy
    broadcasting cannot make tensors of different shapes compare as equal
    (e.g. a (2, 2) tensor of ones against a (2,) tensor of ones).  Tensors
    with mismatched shapes simply yield 1 instead of raising or warning.

    :param x: First tensor
    :param y: Second tensor
    :return: 0 if x and y are equal, 1 otherwise
    """
    return 0 if np.array_equal(x, y) else 1
def mean_float(x: Tensor, y: Tensor) -> float:
    """
    Mean absolute difference between input values and ground truth values.

    Both tensors are promoted to (at least) float64 before subtracting.
    Subtracting in the inputs' own dtype would wrap around for unsigned or
    narrow integer tensors (e.g. uint8: 3 - 5 == 254), which would grossly
    overstate the error for quantized outputs.

    :param x: Input values
    :param y: Ground truth values
    :return: mean(|y - x|) as a Python float
    """
    x_arr = np.asarray(x)
    y_arr = np.asarray(y)
    # Promote by dtype (not by value) to a floating type at least as wide as float64.
    common = np.result_type(x_arr.dtype, y_arr.dtype, np.float64)
    diff = y_arr.astype(common) - x_arr.astype(common)
    return float(np.mean(np.abs(diff)))
@dataclasses.dataclass
class _ThresholdCounterState:
    """Mutable counters accumulated by threshold_test_counter."""
    # How many (x, y) pairs have been seen so far.
    sample_count: int = 0
    # How many of those pairs had a metric value below the threshold.
    success_count: int = 0


@dataclasses.dataclass
class _MeanValues:
    """Per-pair metric values collected by mean and mean_text."""
    # One metric value per consumed pair; default_factory gives every
    # instance its own list instead of a shared mutable default.
    mean_values: List[float] = dataclasses.field(default_factory=list)
def threshold_test_counter(
        metric: Metric,
        threshold: float,
) -> Statistic[Tuple[Tensor, Tensor], str]:
    """
    Build a Statistic that consumes a stream of (x, y) pairs and tallies how
    many of them satisfy metric(x, y) < threshold.

    :param metric: Distance metric applied to each pair
    :param threshold: Strict upper bound a pair's metric value must fall below to pass
    :return: Statistic whose result is a "<passed> of <total> tests passed" message
    """
    def initialize() -> _ThresholdCounterState:
        # Start with zero samples and zero successes.
        return _ThresholdCounterState()

    def update(state: _ThresholdCounterState, x: Tuple[Tensor, Tensor]) -> None:
        actual, expected = x
        # The sample is counted before the metric is evaluated.
        state.sample_count += 1
        if metric(actual, expected) < threshold:
            state.success_count += 1

    def finish(state: _ThresholdCounterState) -> str:
        return f"{state.success_count} of {state.sample_count} tests passed"

    return Statistic(initialize, update, finish)
def tensor_set_statistics(
        statistics: List[Statistic[Tuple[Any, Any], str]],
) -> Statistic[Tuple[List[Any], List[Any]], str]:
    """
    Compose per-position Statistics over a stream of pairs of fixed-length lists.

    Intended for models with several outputs, each paired with its own ground
    truth value: the i-th Statistic sees only the i-th output and the i-th
    ground truth value of every pair in the stream.

    :param statistics: One Statistic per list position
    :return: Statistic whose state is the list of per-position states and whose
        result is the per-position results joined by newlines
    """
    def initialize() -> List[Any]:
        return [stat.initialize() for stat in statistics]

    def update(state: Any, x: Tuple[List[Any], List[Any]]) -> None:
        outputs, references = x
        assert len(statistics) == len(state) == len(outputs) == len(references)
        for stat, sub_state, out, ref in zip(statistics, state, outputs, references):
            stat.update(sub_state, (out, ref))

    def finish(state: Any) -> str:
        assert len(statistics) == len(state)
        return "\n".join(stat.finish(sub_state) for stat, sub_state in zip(statistics, state))

    return Statistic(initialize, update, finish)
def mean(metric: Metric) -> Statistic[Tuple[Tensor, Tensor], float]:
    """
    Create a statistic that takes input pairs (i, g) and computes the
    arithmetic mean of metric(i, g) over all given inputs.

    If the stream was empty, the result is NaN.  It is returned explicitly
    rather than via ``np.mean([])``, which produces the same NaN but also
    emits a "Mean of empty slice" RuntimeWarning.

    :param metric: Metric applied to each (i, g) pair.
    :return: Statistic that captures the data
    """
    def initialize() -> _MeanValues:
        return _MeanValues()

    def update(state: _MeanValues, x: Tuple[Tensor, Tensor]) -> None:
        x1, x2 = x
        state.mean_values.append(metric(x1, x2))

    def finish(state: _MeanValues) -> float:
        if not state.mean_values:
            return float("nan")
        return float(np.mean(state.mean_values))

    return Statistic(initialize, update, finish)
def mean_text(metric: Metric) -> Statistic[Tuple[Tensor, Tensor], str]:
    """
    Create a statistic that takes input pairs (i, g), computes the arithmetic
    mean of metric(i, g) over all given inputs and formats the result as a
    text message.

    If the stream was empty, the reported mean is NaN.  It is produced
    explicitly rather than via ``np.mean([])``, which yields the same NaN but
    also emits a "Mean of empty slice" RuntimeWarning.

    :param metric: Metric applied to each (i, g) pair.
    :return: Statistic that captures the data
    """
    def initialize() -> _MeanValues:
        return _MeanValues()

    def update(state: _MeanValues, x: Tuple[Tensor, Tensor]) -> None:
        x1, x2 = x
        state.mean_values.append(metric(x1, x2))

    def finish(state: _MeanValues) -> str:
        values = state.mean_values
        average = float(np.mean(values)) if values else float("nan")
        return f'Arithmetic mean for {len(values)} given inputs: {average}'

    return Statistic(initialize, update, finish)