Source code for afe.backends.backend_checker

#########################################################
# Copyright (C) 2021 SiMa Technologies, Inc.
#
# This material is SiMa proprietary and confidential.
#
# This material may not be copied or distributed without
# the express prior written permission of SiMa.
#
# All rights reserved.
#########################################################
# Code owner: Joey Chou
#########################################################
"""
This file contains the classes used to implement a backend checker for each IR
and to execute those checkers. The two main classes are:

    1.  ExprCheckInfo - A class that acts as the interface to every check function.
        Each check function (a Predicate) takes an ExprCheckInfo object as its
        input argument. When implementing an IR checker, every argument needed to
        perform the check functions must be present in the ExprCheckInfo object.
        An ExprCheckInfo instance should be created for each operator, and the
        backend checker's check() method should be called to determine whether
        the operator is supported by that backend.

    2.  BaseChecker - The base class for checkers; it is not meant to be
        instantiated. Developers should subclass it once for each supported
        backend and override its class variables (_checkers_name, _backend and
        _predicate). The checker's check() method applies _predicate to an
        ExprCheckInfo. The predicate may dispatch on ExprCheckInfo's 'name'
        attribute to a per-operator check function, which decides whether the
        backend supports the operator using ExprCheckInfo's 'attrs' and
        'input_shapes'. The result is a Decision: Accept, or Reject carrying a
        list of error codes.

    Example:
        class MLACheckers(BaseChecker):
            _checkers_name: str = "MLA Checkers"
            _backend = Backend.MLA
            _predicate = paccept  # Accept everything
"""
from dataclasses import dataclass
from typing import List, Tuple, Callable, Sequence, Dict
from attrdict import AttrDict
from afe.backends import Backend

# Shapes of an expression's input tensors: one tuple of dimension sizes per input.
CheckerInputShapes = List[Tuple[int, ...]]
# Attributes of a TVM expression, wrapped in an AttrDict.
CheckerAttr = AttrDict
@dataclass
class Rejections:
    """
    Record of why a backend checker rejected operators; used by tests that
    verify the checker's decisions.

    ``error_codes[i]`` holds the error codes reported for the i-th operator
    that was examined. An operator that was accepted has an empty list.
    """

    error_codes: List[List[int]]
@dataclass
class Decision:
    """
    Base type for a checker's verdict about a single expression.

    Decision is an algebraic data type: its variants are its subclasses, and
    the base type itself carries no data.
    """
@dataclass
class Accept(Decision):
    """
    Verdict that assigns the expression to the backend selected by the checker.
    """
@dataclass
class Reject(Decision):
    """
    Verdict that keeps the expression off the selected backend.

    :param error_codes: Codes giving the reasons the expression was rejected.
    """

    error_codes: List[int]
def decision_from_bool(b: bool, error_code: int) -> Decision:
    """
    Convert a boolean verdict into a Decision.

    :param b: Truthy to accept, falsy to reject.
    :param error_code: Error code attached when rejecting. Error codes should
        match the numbers in unsupported_code_codebook.
    :return: Accept() when b is truthy, otherwise Reject([error_code]).
    """
    if not b:
        return Reject([error_code])
    return Accept()
@dataclass
class ExprCheckInfo:
    """
    Properties of one Relay IR expression that matter for backend assignment.

    A backend checker receives an ExprCheckInfo and decides from it whether
    the expression can run on that checker's backend.

    :param name: Name of the expression's operator.
    :param attrs: Relay operator attributes. For a composite operator the list
        has one entry per call in the composite operator's body; otherwise it
        has a single entry holding the expression's own attributes.
    :param input_shapes: Shapes of the expression's input tensors.
    :param is_constant: Boolean values indicating whether the corresponding
        inputs are constants.
    :param idx: The expression's index in the graph's topological order.
    """

    # Field order is the positional constructor order; keep it stable.
    name: str
    attrs: List[CheckerAttr]
    input_shapes: CheckerInputShapes
    is_constant: List[bool]
    idx: int
# Signature of a backend-assignment predicate: given an expression's
# ExprCheckInfo, it returns a Decision saying whether the expression may be
# assigned to the backend.
Predicate = Callable[[ExprCheckInfo], Decision]
def pany(ps: Sequence[Predicate]) -> Predicate:
    """
    Make a predicate that returns Accept if any predicate in the list returns Accept.

    Predicates are evaluated in order, and evaluation stops at the first
    Accept, which is returned unchanged. When every predicate rejects, the
    returned Reject contains the concatenation of all predicates' error codes.
    An empty list yields a predicate that always rejects with no error codes.

    :param ps: Predicates to combine. They are copied when pany is called, so
        later changes to the caller's sequence (or exhaustion of a one-shot
        iterable) do not affect the combined predicate.
    :return: The combined predicate.
    :raises TypeError: Raised by the returned predicate if a predicate returns
        something other than Accept or Reject.
    """
    # Snapshot the predicates: a generator would otherwise be exhausted after
    # the first call, and a caller-owned list could change underneath us.
    predicates = tuple(ps)

    def pany_predicate(s: ExprCheckInfo) -> Decision:
        error_codes = []
        for p in predicates:
            r = p(s)
            if isinstance(r, Accept):
                return r
            # Explicit check rather than assert, so it survives `python -O`.
            if not isinstance(r, Reject):
                raise TypeError(
                    f"Predicate {p!r} returned {type(r).__name__}; "
                    "expected Accept or Reject"
                )
            error_codes.extend(r.error_codes)
        return Reject(error_codes)

    return pany_predicate
def pall(ps: Sequence[Predicate]) -> Predicate:
    """
    Make a predicate that returns Accept if all predicates in the list return Accept.

    All predicates in the list are evaluated, even after one rejects. When
    Reject is returned, it contains the concatenation of all predicates' error
    codes. An empty list yields a predicate that always accepts.

    :param ps: Predicates to combine. They are copied when pall is called, so
        later changes to the caller's sequence (or exhaustion of a one-shot
        iterable) do not affect the combined predicate.
    :return: The combined predicate.
    :raises TypeError: Raised by the returned predicate if a predicate returns
        something other than Accept or Reject.
    """
    # Snapshot the predicates: an exhausted generator would otherwise make
    # every later call accept unconditionally.
    predicates = tuple(ps)

    def pall_predicate(s: ExprCheckInfo) -> Decision:
        error_codes = []
        accept = True
        for p in predicates:
            r = p(s)
            if isinstance(r, Reject):
                accept = False
                error_codes.extend(r.error_codes)
            elif not isinstance(r, Accept):
                # Previously any non-Reject value (e.g. False) counted as
                # Accept; reject malformed results loudly instead.
                raise TypeError(
                    f"Predicate {p!r} returned {type(r).__name__}; "
                    "expected Accept or Reject"
                )
        return Accept() if accept else Reject(error_codes)

    return pall_predicate
# Predicate that accepts every expression: combining zero predicates with
# pall can never produce a rejection.
paccept = pall(())
class BaseChecker:
    """
    Decides whether a given expression can be executed on one selected backend.

    Use this class by subclassing it and overriding the class variables listed
    below; it is not meant to be instantiated.

    :param _checkers_name: str. Name of the checker's factory class.
    :param _backend: The type of the Backend.
    :param _predicate: Predicate that decides whether an operator can execute on the backend.
    """

    _checkers_name: str = ""
    _backend: Backend = None
    _predicate: Predicate

    @classmethod
    def get_backend(cls) -> Backend:
        """
        Return the backend this checker class makes assignment decisions for.
        """
        backend = cls._backend
        return backend

    @classmethod
    def check(cls, args: ExprCheckInfo) -> Decision:
        """
        Decide, from an expression's properties, whether the expression can be
        assigned to this checker's associated backend.

        :param args: Properties of the expression being examined.
        :return: The Decision produced by applying the class's ``_predicate``
            to ``args``.
        """
        predicate = cls._predicate
        return predicate(args)