#########################################################
# Copyright (C) 2024 SiMa Technologies, Inc.
#
# This material is SiMa proprietary and confidential.
#
# This material may not be copied or distributed without
# the express prior written permission of SiMa.
#
# All rights reserved.
#########################################################
# Code owner: Christopher Rodrigues
#########################################################
"""
Shared definitions for optimization of requantize nodes.
"""
from typing import Optional, List, TypeVar, Tuple, Union
import numpy as np
from ml_kernels.requantization import BaseRequantization
from afe.ir.defines import DataValue
class _NotFound:
"""A nonce value used to signal that mapping lookup has failed."""
pass
_NotFoundValue = _NotFound()
# A way that a tensor is requantized. This type is used to record the fact
# that this requantization must be computed, or to look up information
# related to a requantization that must be computed.
# None represents a use of a tensor without any requantization.
Need = Optional[BaseRequantization[np.ndarray]]

# Requantizations to perform for one data value.
# Lists are used to hold all needs for a tensor.
DataNeeds = DataValue[List[Need]]

# Type of the value associated with each Need in a NeedMapping.
# The generic NeedMapping helpers in this module are parameterized by it;
# it must be defined before NeedMapping and before any annotation using A.
A = TypeVar("A")

# NeedMapping[A] is a mapping from Need to A.
# It associates an A with every way in which a tensor is used. An empty list
# means the tensor is not used.
# A list of (key, value) pairs is used instead of a dict because
# BaseRequantization[np.ndarray] is not hashable, so lookups are linear
# scans that compare keys with ==.
# Objects that are used as keys must not be mutated.
NeedMapping = List[Tuple[Need, A]]
def need_mapping_empty() -> NeedMapping[A]:
    """
    Create an empty NeedMapping.

    A new list is returned on every call, so callers may mutate the result
    (e.g. with need_mapping_insert) without affecting other mappings.
    """
    return []
def need_mapping_singleton(k: Need, v: A) -> NeedMapping[A]:
    """
    Build a new NeedMapping whose only entry associates need k with value v.
    """
    only_entry = (k, v)
    return [only_entry]
def need_mapping_insert(m: NeedMapping[A], k: Need, v: A) -> None:
    """
    Add the entry (k, v) to m, mutating m in place.

    If an entry whose key compares equal to k is already present, the first
    such entry is replaced in its current position by (k, v). Otherwise
    (k, v) is appended to the end of m.
    """
    for position, (existing_key, _old_value) in enumerate(m):
        if k == existing_key:
            m[position] = (k, v)
            break
    else:
        # The loop finished without a match, so k is a new key.
        m.append((k, v))
def _need_mapping_find_impl(m: NeedMapping[A], k: Need) -> Union[A, _NotFound]:
for k2, v in m:
if k == k2:
return v
return _NotFoundValue
def need_mapping_get(m: NeedMapping[A], k: Need) -> Optional[A]:
    """
    Return the value that m associates with need k, or None if k is absent.

    A None result cannot be distinguished from an entry whose stored value
    is None; use need_mapping_find when a missing key must be detected.
    """
    lookup = _need_mapping_find_impl(m, k)
    return None if isinstance(lookup, _NotFound) else lookup
def need_mapping_find(m: NeedMapping[A], k: Need) -> A:
    """
    Return the value that m associates with need k.

    Raises:
        KeyError: if m contains no entry whose key compares equal to k.
    """
    lookup = _need_mapping_find_impl(m, k)
    if isinstance(lookup, _NotFound):
        raise KeyError("Cannot find value for key " + str(k))
    return lookup