#########################################################
# Copyright (C) 2024 SiMa Technologies, Inc.
#
# This material is SiMa proprietary and confidential.
#
# This material may not be copied or distributed without
# the express prior written permission of SiMa.
#
# All rights reserved.
#########################################################
# Code owner: Christopher Rodrigues
#########################################################
"""
Analysis of statistics on tensors.
"""
import dataclasses
from typing import Callable, Generic, TypeVar, Type, Any, Tuple, List
import numpy as np
# Type of a single input element passed to an ``update`` step.
_I = TypeVar('_I')
# Type of the result returned by a ``finish`` step.
_O = TypeVar('_O')
# Type of the accumulator state threaded through Statistic's
# initialize/update/finish callables.
_S = TypeVar('_S')
# Auxiliary type: the new input type in Statistic.comap, or the new
# output type in Statistic.map.
_A = TypeVar('_A')
class StatisticInstance(Generic[_I, _O]):
    """
    A collector of statistics on a stream of inputs.

    With an instance of the class, inputs are supplied
    repeatedly by calling update, then the computed result
    is read by calling finish.

    This base class only declares the interface: both methods raise
    NotImplementedError until a subclass overrides them.
    """

    def update(self, x: _I) -> None:
        """
        Supply the next input from the stream.

        :param x: Input to incorporate into the statistic
        :raises NotImplementedError: Always, in this base class
        """
        raise NotImplementedError("Method is abstract")

    def finish(self) -> _O:
        """
        Read the result computed from the supplied inputs.

        :return: The computed result
        :raises NotImplementedError: Always, in this base class
        """
        raise NotImplementedError("Method is abstract")
@dataclasses.dataclass(frozen=True)
class Statistic(Generic[_I, _O]):
    """
    A method of computing a property over a stream of inputs.

    The property is normally a reduction, e.g., the average of
    the mean squared errors of all inputs.

    A Statistic is an immutable bundle of three callables that share a
    state object.  The state type is not a parameter of this class; it
    only has to be consistent between the three callables.
    """

    # Creates a fresh state object.  Called with no arguments by the
    # constructor of the class built in instantiate_type.
    initialize: Callable[[], _S]

    # Incorporates one input into the state.  Annotated as returning
    # None, so it is expected to record the input by mutating the state.
    update: Callable[[_S, _I], None]

    # Computes the result from the state.
    finish: Callable[[_S], _O]

    def instantiate_type(self) -> Type[StatisticInstance[_I, _O]]:
        """
        Build a StatisticInstance subclass that computes this statistic.

        Each instance of the returned class owns a state obtained by
        calling initialize, and forwards its update and finish methods
        to this statistic's callables.

        :return: A new StatisticInstance subclass (a class, not an instance)
        """
        # Bring members into local scope before shadowing self
        initialize = self.initialize
        update = self.update
        finish = self.finish

        class ConstructedStatisticInstance(StatisticInstance[_I, _O]):
            def __init__(self) -> None:
                self._state = initialize()

            def update(self, x: _I) -> None:
                # Inside this method body, "update" is the local captured
                # above, not this method: class attributes are not in the
                # enclosing scope of a method body.
                update(self._state, x)

            def finish(self) -> _O:
                return finish(self._state)

        return ConstructedStatisticInstance

    def comap(self, f: Callable[[_A], _I]) -> "Statistic[_A, _O]":
        """
        Map a function over the input stream.

        The returned statistic applies f to every input before passing
        it to this statistic's update; initialize and finish are reused
        unchanged.

        :param f: Function to apply to inputs
        :return: Transformed Statistic
        """
        def new_update(state: Any, x: _A) -> None:
            self.update(state, f(x))

        return dataclasses.replace(self, update=new_update)

    def map(self, f: Callable[[_O], _A]) -> "Statistic[_I, _A]":
        """
        Map a function over the output.

        The returned statistic applies f to the value produced by this
        statistic's finish; initialize and update are reused unchanged.

        :param f: Function to apply to output
        :return: Transformed Statistic
        """
        def new_finish(state: Any) -> _A:
            return f(self.finish(state))

        return dataclasses.replace(self, finish=new_finish)