#########################################################
# Copyright (C) 2022 SiMa Technologies, Inc.
#
# This material is SiMa proprietary and confidential.
#
# This material may not be copied or distributed without
# the express prior written permission of SiMa.
#
# All rights reserved.
#########################################################
# Code owner: Christopher Rodrigues
#########################################################
"""
Analysis of statistics on tensors.
"""
import dataclasses
from typing import Callable, Generic, TypeVar, Type, Any, Tuple, List
import numpy as np
from afe.apis.compilation_job_base import Tensor
from afe.driver.statistic import Statistic
# Re-exported: not used below, imported so callers can import it from this module
from afe.core.evaluate_networks import checked_zip
# A distance metric on tensors
[docs]
# Type of a distance metric: metric(x, y) returns a float where smaller values
# mean x and y are closer (equality() returns 0 for equal tensors, and
# threshold_test_counter() treats metric(x, y) < threshold as a pass).
Metric = Callable[[Tensor, Tensor], float]
[docs]
def equality(x: Tensor, y: Tensor) -> float:
    """
    Equality as a distance metric.

    Return 0 if the tensors have the same shape and the same elements,
    1 otherwise.

    ``np.array_equal`` is used instead of ``(x == y).all()`` because the
    element-wise comparison broadcasts. A tensor of shape (1,) would compare
    equal to a tensor of shape (3,) holding the same repeated value. Shapes
    that cannot be broadcast together do not reliably report inequality.
    NaN elements never compare equal, so tensors containing NaN yield 1.

    :param x: First tensor
    :param y: Second tensor
    :return: 0.0 if x and y are identical in shape and contents, else 1.0
    """
    return 0.0 if np.array_equal(x, y) else 1.0
[docs]
def mean_float(x: Tensor, y: Tensor) -> float:
    """
    Mean absolute difference between input values and ground truth values.

    Computes ``mean(|y - x|)`` over all elements, with x and y broadcast
    against each other as in numpy arithmetic.

    If the common dtype of the inputs is not floating point or complex (for
    example, integer tensors from a quantized model), the subtraction is done
    in float64. Without this, integer subtraction wraps around: for example,
    ``uint8(3) - uint8(5)`` evaluates to 254 instead of -2. Floating-point
    and complex inputs keep their common dtype, as before.

    :param x: Tensor of input values
    :param y: Tensor of ground truth values
    :return: Mean of the element-wise absolute differences
    """
    x_arr = np.asarray(x)
    y_arr = np.asarray(y)
    diff_dtype = np.result_type(x_arr, y_arr)
    if not np.issubdtype(diff_dtype, np.inexact):
        # Promote integer inputs so the difference cannot overflow or wrap.
        diff_dtype = np.float64
    diff = np.subtract(y_arr, x_arr, dtype=diff_dtype)
    return float(np.mean(np.abs(diff)))
@dataclasses.dataclass
class _ThresholdCounterState:
sample_count: int = 0
success_count: int = 0
@dataclasses.dataclass
class _MeanValues:
mean_values: List[float] = dataclasses.field(default_factory=list)
[docs]
def threshold_test_counter(metric: Metric, threshold: float) \
        -> Statistic[Tuple[Tensor, Tensor], str]:
    """
    Build a Statistic over a stream of ``(x, y)`` pairs that counts how many
    pairs satisfy ``metric(x, y) < threshold``.

    :param metric: Distance metric evaluated on every pair
    :param threshold: Strict upper bound a metric value must fall below for
        the pair to count as passed
    :return: Statistic whose result is a "<passed> of <total> tests passed"
        message
    """
    def _fresh_state() -> _ThresholdCounterState:
        return _ThresholdCounterState()

    def _record(counts: _ThresholdCounterState, pair: Tuple[Tensor, Tensor]) -> None:
        output, reference = pair
        # The sample is counted before the metric is evaluated.
        counts.sample_count += 1
        if metric(output, reference) < threshold:
            counts.success_count += 1

    def _summarize(counts: _ThresholdCounterState) -> str:
        passed, total = counts.success_count, counts.sample_count
        return f"{passed} of {total} tests passed"

    return Statistic(_fresh_state, _record, _summarize)
[docs]
def tensor_set_statistics(statistics: List[Statistic[Tuple[Any, Any], str]]) \
        -> Statistic[Tuple[List[Any], List[Any]], str]:
    """
    Apply an independent Statistic to each position in a stream of pairs of
    fixed-length lists.

    This is intended for evaluating models that have multiple outputs and a
    ground truth value corresponding to each output. Each Statistic is
    applied to one of the outputs and its ground truth value:
    ``statistics[i]`` receives ``(tensors[i], ground_truths[i])``.

    :param statistics: Statistic to apply to each pair of values, one per
        list position
    :return: Composed statistic that applies the statistics to list items.
        Its result is the per-position results joined with newlines.
    :raises ValueError: (from the returned statistic's update/finish) if the
        lists' lengths do not match the number of statistics
    """
    def initialize() -> List[Any]:
        # One independent state per wrapped statistic, in the same order.
        return [s.initialize() for s in statistics]

    def update(state: List[Any], x: Tuple[List[Any], List[Any]]) -> None:
        tensors, ground_truths = x
        # Explicit check rather than ``assert``: asserts are removed under
        # ``python -O``, after which zip() would silently drop extra items.
        if not len(statistics) == len(state) == len(tensors) == len(ground_truths):
            raise ValueError(
                f"Length mismatch: {len(statistics)} statistics, "
                f"{len(state)} states, {len(tensors)} tensors, "
                f"{len(ground_truths)} ground truths"
            )
        for stat, stat_state, tensor, ground_truth in zip(
                statistics, state, tensors, ground_truths):
            stat.update(stat_state, (tensor, ground_truth))

    def finish(state: List[Any]) -> str:
        if len(statistics) != len(state):
            raise ValueError(
                f"Length mismatch: {len(statistics)} statistics, "
                f"{len(state)} states"
            )
        return "\n".join(stat.finish(stat_state)
                         for stat, stat_state in zip(statistics, state))

    return Statistic(initialize, update, finish)
[docs]
def mean(metric: Metric) -> Statistic[Tuple[Tensor, Tensor], float]:
    """
    Create a statistic that takes input pairs (i, g) and computes the
    arithmetic mean of metric(i, g) over all given inputs.

    Each metric value is stored and averaged with ``np.mean`` when the
    statistic finishes. If no pairs were given, ``np.mean`` of an empty list
    yields NaN and numpy emits a RuntimeWarning.

    :param metric: Distance metric applied to each (i, g) pair
    :return: Statistic whose result is the mean metric value as a float
    """
    def initialize() -> _MeanValues:
        return _MeanValues()

    def update(state: _MeanValues, x: Tuple[Tensor, Tensor]) -> None:
        x1, x2 = x
        state.mean_values.append(metric(x1, x2))

    def finish(state: _MeanValues) -> float:
        return float(np.mean(state.mean_values))

    return Statistic(initialize, update, finish)
[docs]
def mean_text(metric: Metric) -> Statistic[Tuple[Tensor, Tensor], str]:
    """
    Create a statistic that takes input pairs (i, g), computes the arithmetic
    mean of metric(i, g) over all given inputs, and formats the result as a
    text message.

    The message reports how many inputs were seen and their mean metric
    value. If no pairs were given, the reported mean is NaN (``np.mean`` of
    an empty list, which also emits a RuntimeWarning).

    :param metric: Distance metric applied to each (i, g) pair
    :return: Statistic whose result is a human-readable summary string
    """
    def initialize() -> _MeanValues:
        return _MeanValues()

    def update(state: _MeanValues, x: Tuple[Tensor, Tensor]) -> None:
        x1, x2 = x
        state.mean_values.append(metric(x1, x2))

    def finish(state: _MeanValues) -> str:
        return f'Arithmetic mean for {len(state.mean_values)} given inputs: {float(np.mean(state.mean_values))}'

    return Statistic(initialize, update, finish)