Source code for afe.common_utils

#########################################################
# Copyright (C) 2021 SiMa Technologies, Inc.
#
# This material is SiMa proprietary and confidential.
#
# This material may not be copied or distributed without
# the express prior written permission of SiMa.
#
# All rights reserved.
#########################################################
# Code owner: Joey Chou
#########################################################
"""
This file contains functions that can be used in all AFE source code
"""
from fnmatch import fnmatch
from enum import Enum
from dataclasses import dataclass
from typing import Optional, Tuple, List, Union, Set
import subprocess
import os

from sima_utils.logging import sima_logger


class EnumHelper(Enum):
    """
    Enum base class that reports the supported values when a lookup by value fails.

    Subclass it instead of ``Enum`` to get the ``values`` helper and a more
    informative ``ValueError`` message from ``_missing_``.
    """

    @classmethod
    def values(cls) -> List[str]:
        """
        List the value of every member of the enum, in iteration order.

        Return
        ------
        :return: List[str]. The ``value`` attribute of each member.
        """
        return [member.value for member in cls]

    @classmethod
    def _missing_(cls, value):
        """
        Hook called by Enum when ``value`` does not correspond to any member.

        Parameters
        ----------
        :param value: The value that failed the lookup.

        Raise
        -----
        :raise ValueError: Always. The message names the enum class and lists
            the supported values.
        """
        raise ValueError(
            f"'{value}' is not a valid {cls.__name__}. Only support {cls.values()}"
        )
class Singleton(type):
    """
    Metaclass that makes every class using it produce at most one instance.

    The first call to the class constructs the instance and stores it; every
    later call returns the stored instance and ignores its arguments.
    """

    # Maps each class using this metaclass to its single instance.
    _instances = {}

    def __call__(cls, *args, **kwargs):
        registry = cls._instances
        if cls not in registry:
            # First instantiation: build the object through the normal path.
            registry[cls] = super().__call__(*args, **kwargs)
        return registry[cls]
def get_index_from_node_name(name: str, splitter="_") -> Optional[int]:
    """
    Use the postfix after the last splitter of the given name. If the postfix is
    an integer, return it as the index.

    If the splitter does not occur in the name, the whole name is treated as the
    postfix.

    Parameters
    ----------
    :param name: str. Name of the node.
    :param splitter: str. Used to locate the postfix. May be longer than one character.

    Return
    ------
    :return: Optional[int]. Return None if there is no index in the given node name
        else return the index in int.
    """
    cut = name.rfind(splitter)
    # Skip the full splitter, not just one character, so that multi-character
    # splitters such as "__" are handled correctly.
    postfix = name if cut < 0 else name[cut + len(splitter):]
    # isdecimal() only accepts characters that int() can parse; isdigit() also
    # accepts characters such as superscript digits, which make int() raise.
    return int(postfix) if postfix.isdecimal() else None
def unroll_tuple_range(tuple_range: Tuple[int, int]) -> List[int]:
    """
    Expand an inclusive (lower, upper) pair into the list of integers it covers.

    Parameters
    ----------
    :param tuple_range: Tuple[int, int]. The first element is the lower bound and
        the second element is the upper bound. Both bounds are included.

    Return
    ------
    :return: List[int]. Consecutive integers from the lower bound to the upper
        bound. Empty if the upper bound is smaller than the lower bound.
    """
    assert (
        len(tuple_range) == 2
        and all(isinstance(bound, int) for bound in tuple_range)
    ), f"To unroll a tuple range, the range has to be a tuple of two integers. Got {tuple_range}"

    first, last = tuple_range[0], tuple_range[1]
    # The upper bound is inclusive, hence the +1.
    return list(range(first, last + 1))
def parse_indices(indices_list: List[Union[int, Tuple[int, int]]]
                  ) -> List[int]:
    """
    Flatten a list of indices and index ranges into a sorted list of unique integers.

    Each element of the input is either:
        1. A single integer.
        2. A tuple of two integers (lower bound, upper bound), both inclusive,
           which is expanded with unroll_tuple_range.

    Example
    -------
    .. code-block:: python

        indices_list = [2, (11, 14), 7]
        decoded_indices_list = parse_indices(indices_list)
        # decoded_indices_list = [2, 7, 11, 12, 13, 14]

    Parameters
    ----------
    :param indices_list: List[Union[int, Tuple[int, int]]]. Indices given either as
        integers or as tuples of two integers.

    Return
    ------
    :return: List[int]. The indices in ascending order, without duplicates.

    Raise
    -----
    :raise ValueError: If an element is neither an integer nor a tuple.
    """
    unique_indices = set()
    for entry in indices_list:
        if isinstance(entry, tuple):
            unique_indices.update(unroll_tuple_range(entry))
        elif isinstance(entry, int):
            unique_indices.add(entry)
        else:
            raise ValueError("Element in the indices list must be an integer or a tuple of"
                             f" two integers. Got {entry}")
    return sorted(unique_indices)
def generate_node_name_patterns(patterns: List[Union[str, int, Tuple[int, int]]]
                                ) -> List[str]:
    """
    Given patterns in List[Union[str, int, Tuple[int, int]]], generate node name
    patterns using the different types of pattern as below:
        1. str: Added to the output list unchanged. Support wildcard.
        2. int: Converted to the string ``*_{number}``.
        3. Tuple[int, int]: Unrolled with unroll_tuple_range into the inclusive range
           of integers between the two bounds. Each integer is converted to
           ``*_{number}``.

    The output lists all str patterns first, in their input order, followed by the
    patterns generated from integers and tuples, in their input order.

    NOTE(review): a str made of digits (e.g. "3") is passed through unchanged; it is
    not converted to "*_3". Confirm with callers whether that is the intent.

    Example
    -------
    .. code-block:: python

        patterns = [2, (11, 14), "*_conv*"]
        str_patterns = generate_node_name_patterns(patterns)
        # str_patterns = ["*_conv*", "*_2", "*_11", "*_12", "*_13", "*_14"]

    Parameters
    ----------
    :param patterns: List[Union[str, int, Tuple[int, int]]]

    Return
    ------
    :return: List[str]. List of str patterns.

    Raise
    -----
    :raise ValueError: If an element is not a str, an int or a tuple.
    """
    str_patterns: List[str] = []
    index_values: List[int] = []
    for p in patterns:
        if isinstance(p, tuple):
            index_values.extend(unroll_tuple_range(p))
        elif isinstance(p, int):
            index_values.append(p)
        elif isinstance(p, str):
            str_patterns.append(p)
        else:
            # The two sentences are separated explicitly; adjacent string literals
            # are concatenated without any separator.
            raise ValueError(f"Unknown pattern type {type(p)} in the given patterns {patterns}. "
                             "Only support [str, int, Tuple[int, int]] pattern")

    # Convert int patterns to str patterns with wildcard
    return str_patterns + [f"*_{i}" for i in index_values]
def search_matched_node_names(node_names: List[str],
                              patterns: List[Union[str, int, Tuple[int, int]]],
                              excluded_patterns: Optional[List[Union[str, int, Tuple[int, int]]]] = None
                              ) -> Set[str]:
    """
    Select the node names that match at least one of the given patterns and none
    of the excluded patterns.

    Both pattern lists are first converted to wildcard strings with
    generate_node_name_patterns and then matched against each name with fnmatch:
        1. str: Used as the wildcard pattern unchanged.
        2. int: Matched as ``*_{number}``.
        3. Tuple[int, int]: Unrolled to the inclusive range of integers between the
           two bounds; each integer is matched as in 2 above.

    Example
    -------
    The example selects the nodes whose name ends with an index in
    [2, 11, 12, 13, 14] and every node whose name contains "conv", except the nodes
    whose name contains "conv2d_transpose".

    .. code-block:: python

        patterns = [2, (11, 14), "*conv*"]
        excluded_patterns = ["*conv2d_transpose*"]
        matched_node_name_set = search_matched_node_names(node_names, patterns, excluded_patterns)

    Parameters
    ----------
    :param node_names: List[str]. Node names to test against the patterns.
    :param patterns: List[Union[str, int, Tuple[int, int]]]. Patterns to select.
    :param excluded_patterns: Optional[List[Union[str, int, Tuple[int, int]]]]. Default is None.
        Patterns to reject. A name matching any of them is removed from the result.

    Return
    ------
    :return: Set[str]. Set of node names that match the given patterns but not the
        excluded patterns.
    """
    wildcard_patterns: List[str] = generate_node_name_patterns(patterns)
    selected: Set[str] = {
        name for name in node_names
        if any(fnmatch(name, wildcard) for wildcard in wildcard_patterns)
    }

    if excluded_patterns is None:
        return selected

    # Only the names selected so far can be excluded, so search among those.
    rejected = search_matched_node_names(list(selected), patterns=excluded_patterns)
    return selected - rejected
def get_afe_git_commit_id() -> Optional[str]:
    """
    Get the last AFE git commit ID if running out of a git repository.
    Return None if the ID can't be found.
    If running an installed package, no attempt is made to get the commit ID.

    :return: commit ID as string, if it can't get ID returns None.
    """
    if is_installed():
        return None

    # Use an absolute directory so that cwd is never the empty string when
    # __file__ is a bare relative file name.
    source_dir = os.path.dirname(os.path.abspath(__file__))
    try:
        result = subprocess.run(['git', 'rev-parse', 'HEAD'],
                                check=False,
                                stdout=subprocess.PIPE,
                                stderr=subprocess.PIPE,
                                cwd=source_dir)
    except OSError:
        # The git executable is missing or cannot be launched, or cwd is not
        # usable. Cannot get the commit ID. Return None.
        return None

    # If the command failed, cannot get the commit ID. Return None.
    if result.returncode != 0:
        return None
    return result.stdout.decode('ascii').strip()
def get_filetype_from_directory(dir_path: str, filetype: str) -> str:
    """
    Return the path of the single entry in a directory whose name contains ``filetype``.

    EG: in './example_dir' find the '.json' file path.

    Parameters
    ----------
    :param dir_path: str. Path to the directory to look into.
    :param filetype: str. Text searched for in each entry name, e.g. '.json'.
        It is matched anywhere in the name, not only at the end.

    Return
    ------
    :return: str. ``dir_path`` joined with the name of the matching entry.

    Raise
    -----
    :raise AssertionError: If ``dir_path`` is not a directory, if no entry matches,
        or if more than one entry matches.
    """
    assert os.path.isdir(dir_path), "Error path supplied must be a path to a directory"

    # Scan the directory once and keep the matching names in listing order.
    candidates = [entry for entry in os.listdir(dir_path) if filetype in entry]

    assert candidates, \
        "Error no filetype ({}) in directory ({})".format(filetype, dir_path)
    assert len(candidates) == 1, \
        "Error. Multiple files have the filetype ({}) in directory ({})".format(filetype, dir_path)

    for entry in candidates:
        return os.path.join(dir_path, entry)
    # Only reachable when assertions are disabled and nothing matched.
    raise FileNotFoundError("Error no filetype ({}) in directory ({})".format(filetype, dir_path))
def _compute_is_installed() -> bool:
    """
    Decide whether this module is running from an installed package rather than
    from a source directory.

    Callers should use is_installed, which memoizes the result, instead of calling
    this function directly.

    :return: True if no setup.py is found next to the 'afe' package directory.
    """
    package_dir, module_file = os.path.split(__file__)
    parent_dir, package_name = os.path.split(package_dir)
    # This check relies on the module living at <parent>/afe/common_utils.py.
    assert (package_name, module_file) == ('afe', 'common_utils.py')

    # If setup.py is found, this is the source tree, so this file is not installed.
    # Otherwise, assume it is installed.
    setup_script = os.path.join(parent_dir, 'setup.py')
    found_setup_script = os.access(setup_script, os.F_OK)
    return not found_setup_script


# Memoized result of _compute_is_installed. None means it was not computed yet.
_IS_INSTALLED: Optional[bool] = None
def is_installed() -> bool:
    """
    Tell whether this module appears to be running from an installed package
    (not a source directory).

    The answer is computed by _compute_is_installed on the first call and stored
    in _IS_INSTALLED; later calls return the stored value.

    :return: True if running from an installed package, False otherwise.
    """
    global _IS_INSTALLED
    if _IS_INSTALLED is not None:
        return _IS_INSTALLED

    _IS_INSTALLED = _compute_is_installed()
    return _IS_INSTALLED
@dataclass
class ARMRuntime:
    """
    Address of the TVM RPC server on which ARM code is run in tests.
    """
    # Host name of the RPC server.
    hostname: str
    # Port of the RPC server.
    port: int
def _compute_arm_target() -> Optional[ARMRuntime]:
    """
    Read the ARM runtime location from the TVM_ARM_SERVER environment variable.

    The variable is expected to hold ``<host>:<port>``. The text after the last
    colon is parsed as the integer port and everything before it is the host.

    :return: The ARM runtime, or None if the variable is not set or cannot be
        parsed. A message is logged in every case.
    """
    server_spec = os.environ.get('TVM_ARM_SERVER')
    if server_spec is None:
        sima_logger.sima_log_info("No ARM environment found; ARM testing will be disabled")
        return None

    try:
        # Split on the last colon only, so the host part may itself contain colons.
        host, port_text = server_spec.rsplit(':', maxsplit=1)
        port = int(port_text)
    except (TypeError, ValueError):
        # No colon in the string, or the port is not an integer.
        sima_logger.sima_log_error("Cannot parse ARM server string; ARM testing will be disabled")
        return None

    sima_logger.sima_log_info(f"Using RPC server {host}:{port} for ARM")
    return ARMRuntime(hostname=host, port=port)


# Memoized ARM target configuration.
# () means the target was not computed.
# None means there is no ARM support.
_ARM_RUNTIME: Union[None, Tuple[()], ARMRuntime] = ()
def get_arm_runtime() -> Optional[ARMRuntime]:
    """
    Give the ARM target information if ARM target support is found on the system.

    The lookup is done by _compute_arm_target on the first call and stored in
    _ARM_RUNTIME; later calls return the stored value.

    :return: The ARM runtime, or None if there is no ARM support.
    """
    global _ARM_RUNTIME
    # The empty tuple marks "not computed yet"; None is a valid computed result.
    needs_lookup = _ARM_RUNTIME == ()
    if needs_lookup:
        _ARM_RUNTIME = _compute_arm_target()
    return _ARM_RUNTIME