Source code for afe.ir.transform.requantization_hoisting.defines

#########################################################
# Copyright (C) 2024 SiMa Technologies, Inc.
#
# This material is SiMa proprietary and confidential.
#
# This material may not be copied or distributed without
# the express prior written permission of SiMa.
#
# All rights reserved.
#########################################################
# Code owner: Christopher Rodrigues
#########################################################
"""
Shared definitions for optimization of requantize nodes.
"""
from typing import Optional, List, TypeVar, Tuple, Union

import numpy as np

from ml_kernels.requantization import BaseRequantization

from afe.ir.defines import DataValue

# Type of the values stored in a NeedMapping.
A = TypeVar("A")


class _NotFound:
    """Private sentinel type signalling that a NeedMapping lookup found no entry."""


# The sentinel instance returned by failed lookups. A dedicated object is used
# rather than None because the stored value type A is unconstrained and may
# itself include None.
_NotFoundValue = _NotFound()

# A way that a tensor is requantized. This type records that a particular
# requantization must be computed, or is used to look up information related
# to a requantization that must be computed.
# None represents a use of a tensor without any requantization.
Need = Optional[BaseRequantization[np.ndarray]]

# Requantizations to perform for one data value.
# Lists are used to hold all needs for a tensor.
DataNeeds = DataValue[List[Need]]

# NeedMapping[A] is an association list from Need to A.
# It associates an A with every way in which a tensor is used; an empty list
# means the tensor is not used.
# A list of pairs is used instead of a dict because
# BaseRequantization[np.ndarray] is not hashable, so keys are found by
# comparing them with == one entry at a time.
# Objects that are used as keys must not be mutated.
NeedMapping = List[Tuple[Need, A]]
def need_mapping_empty() -> NeedMapping[A]:
    """
    Create an empty NeedMapping.

    A new list is returned on every call, so the result may be mutated
    (e.g. with need_mapping_insert) without affecting other mappings.
    """
    return []
def need_mapping_singleton(k: Need, v: A) -> NeedMapping[A]:
    """
    Build a NeedMapping holding exactly one entry, which maps need k to value v.
    """
    only_entry = (k, v)
    return [only_entry]
def need_mapping_insert(m: NeedMapping[A], k: Need, v: A) -> None:
    """
    Store the entry (k, v) in m, mutating m in place.

    If m already has an entry whose key compares equal to k, that entry is
    overwritten in its current position. Otherwise the new entry is appended.
    """
    for position, (existing_key, _) in enumerate(m):
        # The new key is the left operand of ==, as in every lookup in this module.
        if k == existing_key:
            m[position] = (k, v)
            break
    else:
        # No matching key was found.
        m.append((k, v))
def _need_mapping_find_impl(m: NeedMapping[A], k: Need) -> Union[A, _NotFound]:
    # Linear scan: return the value of the first entry whose key equals k,
    # or the _NotFoundValue sentinel when there is no such entry. The sentinel
    # (rather than None) lets callers distinguish "missing" from a stored None.
    matches = (value for key, value in m if k == key)
    return next(matches, _NotFoundValue)
[docs] def need_mapping_get(m: NeedMapping[A], k: Need) -> Optional[A]: """ Get the value associated with need k, or None if k is not found. """ match _need_mapping_find_impl(m, k): case _NotFound(): return None case x: return x
def need_mapping_find(m: NeedMapping[A], k: Need) -> A:
    """
    Look up the value stored for need k, treating absence as an error.

    :param m: Mapping to search. Keys are compared with ``==``.
    :param k: Need to look up.
    :return: The value associated with k.
    :raises KeyError: If m has no entry for k.
    """
    result = _need_mapping_find_impl(m, k)
    if isinstance(result, _NotFound):
        raise KeyError("Cannot find value for key " + str(k))
    return result