Source code for afe.backends.checker_utils

#########################################################
# Copyright (C) 2021 SiMa Technologies, Inc.
#
# This material is SiMa proprietary and confidential.
#
# This material may not be copied or distributed without
# the express prior written permission of SiMa.
#
# All rights reserved.
#########################################################
# Code owner: Joey Chou, Jeffrey Spitz
#########################################################
"""
This file contains utility functions used to decide whether an IR expression
is supported by a given backend.
"""
# TODO(Joey): Add description for each check function.
import numpy as np
from typing import Union, List, Tuple, Callable, Optional

from attrdict import AttrDict

from afe._tvm._defines import DEFAULT_4D_DATA_LAYOUT
from afe._tvm._utils import is_supported_mla_pool_size
from afe.backends import ExprCheckInfo
from afe.backends.backend_checker import Decision, Accept, Reject, decision_from_bool
from afe.ir.utils import afe_warn, is_mla_supported_einsum_equation


def find_matching_attr(match: Callable[[AttrDict], bool], attrs: List[AttrDict], *,
                       description: Optional[str] = None) -> AttrDict:
    """
    Return the first attribute dictionary in ``attrs`` accepted by ``match``.

    :param match: Predicate applied to each attribute dictionary, in list order.
    :param attrs: Attribute dictionaries to search.
    :param description: Optional text describing what is being searched for; when
        given it is appended to the error message.
    :return: The first attribute dictionary for which ``match`` returns True.
    :raises ValueError: If no attribute dictionary matches.
    """
    for candidate in attrs:
        if match(candidate):
            return candidate
    if not description:
        raise ValueError("No attribute found")
    raise ValueError("No attribute found matching " + description)
def _is_pool_attribute(attrs: AttrDict) -> bool:
    """
    Return True if ``attrs`` contains a ``pool_size`` entry.

    Used as the ``match`` predicate with find_matching_attr to locate the pooling
    operator's attributes among a composite's attribute dictionaries.
    """
    has_pool_size = 'pool_size' in attrs
    return has_pool_size
# TODO: Remove unused codes
# Maps the error codes carried by Reject decisions to human-readable messages.
# The comment above an entry names the check in this file (or a former check)
# that produces the code.
unsupported_code_codebook = {
    -1: "Unsupported",
    0: "Pass",
    # padding_is_present_and_included
    1: "Padding is present but count_include_pad is zero.",
    # exclude_of_reduce_expression_is_true
    2: "Does not support having exclude set to True for the reduce expression",
    # keepdims_of_reduce_expression_is_true
    3: "Does not support keepdims set to False for reduce expressions",
    # axes_of_reduce_expression_are_nonzero
    4: "Zero axis must be included when exclude is set to True; "
       " when exclude is set to False the zero axis must not be included."
       " Additionally, axis must not be set to None",
    # zero_axis_of_input_has_shape_of_1
    5: "Zero axis of the input shape must have a value of 1",
    # input_to_relay_call_is_3_dimensional
    6: "Does not support having some of the inputs 3-dimensional",
    # tuple_concat_axis_is_1_or_last
    7: "Does not support concatenating along axes that are not 1 or last(-1)",  # Unused
    # dilations_of_pooling_expression_are_greater_than_1
    8: "Does not support dilations of pooling expressions greater than 1",
    # supported_pool_size
    9: "Pool_size dimensions should all be < 128, except in case of global average pooling",
    # axis_is_none_or_contains_0
    10: "Squeezing along the batch dimension is not supported or axis was set to None",
    # check_resize_is_supported
    11: "Only positive scale factors are supported for resize / upsampling",
    # check_resize_is_supported
    12: "Only linear / bilinear and nearest_neighbor methods are supported for resize / upsampling",
    # corners_not_aligned
    13: "Setting align_corners == True is not supported",
    # pooling_expression_ceil_mode_is_false (rejects when ceil_mode is True)
    14: "Pooling expression ceil_mode must be set to False",
    # supported_mean_operator_axes_and_input_size
    15: "Mean operation over batch or channel axis is not supported and dimension should be less than 128",
    # pool_explicit_padding_incorrect_dimension
    16: "Pool2d explicit padding has incorrect dimension",
    # output_size_shape_is_1
    18: "Output size shape must be 1",
    # pad_mode_is_constant_and_pad_width_is_zero
    19: "Pad mode must be constant and pad width must be 0",
    # tuple_concat_axis_is_not_0
    20: "Axis must not be zero",
    # input_to_relay_call_is_4_dimensional
    21: "Does not support having some of the inputs 4-dimensional",
    # axes_are_nonzero_or_zero_axis_of_input_has_shape_of_1
    22: "Axes are zero and zero axis of input has shape that is not equal to 1"
        " for the reduce expression",
    # dilations_of_conv_transpose_expression_are_1
    23: "Dilations of transpose convolution (2d or 3d) must be 1",
    # number_of_groups_is_1
    24: "Number of groups for transpose conv2d must be 1",
    # number_of_groups_equals_number_of_input_channels
    25: "Number of groups is not equal to the number of input channels",
    # input_into_relay_expression_is_5d_or_4d_or_3d
    27: "Does not support input that is neither three, four, nor five dimensional",  # Unused
    # Does not support crop and resize in image.resize when roi is not None
    28: "image.resize does not support coordinate_transformation_mode to be 'tf_crop_and_resize', roi should be None",
    # PRelu's alpha axis is the channel axis, as required by the MLA
    29: "Prelu is supported only along channel axis",
    # MLA supports StridedSlice with strides == 1
    30: "StridedSlice strides are not equal to 1",
    # MLA does not support StridedSlice where begin or end on channel axis is not multiple of 16
    31: "StridedSlice begin or end is not a multiple of 16",
    # MLA Operators support 4D inputs only
    32: "Input is not 4D tensor",
    33: "Each input must be a 4D tensor or at most one input can be a scalar constant",
    34: "Input must be broadcastable and non Constant",
    35: "Strides are not a power of 2",
    36: "One or more outputs do not have float32 type",
    37: "One or more outputs do not have int8 type",
    38: "Input of argmax operator must have size (1,1,C) with C <= 2032",
    39: "One or more outputs do not have int32 type",
    40: "Reduction must be performed across only axis 3",
    41: "Softmax is supported only along channel axis",
    42: "Number of einsum inputs are not equal to expected number of inputs",
    43: "Unsupported einsum equation",
    44: "Depthwise conv2d transpose is supported only if stride is 1 or 2",
    45: "Pytorch_half_pixel coordinate transformation mode is not supported when size of the resized tensor is 1",
    46: "Unsupported value of strides, only strides in [1, 31] range are supported.",
    47: "Unsupported value of dilation, only dilation in [1, 63] range is supported.",
    48: "Broadcast to is only supported for 4d output, and only for shapes that are broadcastable.",
    49: "Transpose affecting the batch axis is not supported.",
    50: "For resizing by nearest neighbor method only scaling by a power of 2 greater than 1 is supported",
    51: "Input is not 4D or 5D tensor",
    52: "Standalone Variance operator is only supported if it's done along D, H and W axes.",
    53: "Only 2D interpolation with linear mode and non-reflection padding is supported."
}
def unsupported_op_code_to_message(code: Union[int, List[int]]) -> Union[str, List[str]]:
    """
    Translate one error code, or a list of error codes, into codebook messages.

    :param code: A single integer code or a list of integer codes.
    :return: The message for a single code, or a list of messages in the same order
        as the given codes.
    """
    if not isinstance(code, int):
        assert isinstance(code, list) and all(isinstance(i, int) for i in code)
        return [unsupported_code_codebook[single] for single in code]
    return unsupported_code_codebook[code]
def padding_is_present_and_included(args: ExprCheckInfo) -> Decision:
    """
    Accept an average pool whose padding, if any, is counted in the average.

    Accepts when ``count_include_pad`` is truthy or when the padding values sum to
    zero (no padding); otherwise rejects with code 1.
    """
    pool_attrs = find_matching_attr(_is_pool_attribute, args.attrs)
    total_padding = sum(int(amount) for amount in pool_attrs.padding)
    return decision_from_bool(pool_attrs.count_include_pad or total_padding == 0, 1)
def exclude_of_reduce_expression_is_true(args: ExprCheckInfo) -> bool:
    """
    Return the ``exclude`` attribute of the first attribute dictionary.

    If the attribute is missing, emit a warning and return False.
    Note that this returns a plain value, not a Decision.
    """
    reduce_attrs = args.attrs[0]
    if not hasattr(reduce_attrs, 'exclude'):
        afe_warn(f"Cannot interpret reduce expression {reduce_attrs}. Lacks attribute exclude. "
                 f"Defaulting to exclude=False")
        return False
    return reduce_attrs.exclude
def supported_mean_operator_axes_and_input_size(args: ExprCheckInfo) -> Decision:
    """
    Accept a mean reduction that avoids the batch and channel axes and whose reduced
    spatial extent is a supported MLA pool size.

    The batch axis is axis 0 and the channel axis is the last axis. Negative axes are
    normalized against the input rank, so e.g. -1 is recognized as the channel axis.
    ``axis=None`` reduces over every axis (including batch and channel) and is rejected.

    :param args: ``args.attrs[0].axis`` holds the reduce axes (None, an int, or a
        sequence of ints); ``args.input_shapes[0]`` is the input shape.
    :return: Accept(), or a rejection with code 15.
    """
    attrs = args.attrs[0]
    input_shape = args.input_shapes[0]
    rank = len(input_shape)
    batch_axis = 0
    channel_axis = rank - 1
    if attrs.axis is None:
        # Reducing over all axes necessarily includes the batch and channel axes.
        return Reject([15])
    raw_axes = [attrs.axis] if isinstance(attrs.axis, int) else list(attrs.axis)
    # Normalize negative axes so that e.g. -1 is treated as the channel axis.
    reduce_axis = {int(a) + rank if int(a) < 0 else int(a) for a in raw_axes}
    # Reduce on batch or channel dimension is not supported.
    if batch_axis in reduce_axis or channel_axis in reduce_axis:
        return Reject([15])
    # Mean with all spatial axes have been rewritten to GlobalAvgPool which has no size limit.
    # Otherwise, ensure the input size along each reduce axis is a supported pool size.
    pool_size = tuple(
        s if i in reduce_axis else 1 for i, s in enumerate(input_shape[1:-1], start=1)
    )
    return decision_from_bool(is_supported_mla_pool_size(pool_size), 15)
def keepdims_of_reduce_expression_is_true(args: ExprCheckInfo) -> Decision:
    """
    Accept if the reduce attributes have a truthy ``keepdims`` entry.

    A missing ``keepdims`` attribute triggers a warning and a rejection with code 3.
    """
    reduce_attrs = args.attrs[0]
    if not hasattr(reduce_attrs, 'keepdims'):
        afe_warn(f"Cannot interpret reduce expression {reduce_attrs}. Lacks attribute keepdims. "
                 f"Defaulting to keepdims=False")
        return Reject([3])
    return decision_from_bool(reduce_attrs.keepdims, 3)
def axes_of_reduce_expression_are_nonzero(args: ExprCheckInfo) -> Decision:
    """
    Accept if the reduction does not include axis 0 (the batch axis).

    Note: If `exclude` is True, reduction is performed on the axes that are NOT in
    `axis`, so in that case axis 0 must be listed in `axis` for it to be left alone.

    Note: If `axis` is None the reduction is performed across all axes, including
    axis 0, so the expression is rejected.

    Negative axes are normalized against the rank of the first input when it is
    available, so e.g. -4 on a 4D input is recognized as axis 0.
    """
    attrs = args.attrs[0]
    axis = attrs.axis
    exclude = exclude_of_reduce_expression_is_true(args)
    if axis is None:
        return Reject([4])
    rank = len(args.input_shapes[0]) if args.input_shapes else None
    axis_int_list = []
    for a in axis:
        index = int(a)
        if index < 0 and rank is not None:
            # Negative axes count from the end; -rank refers to axis 0.
            index += rank
        axis_int_list.append(index)
    if exclude:
        is_nonzero = 0 in axis_int_list
    else:
        is_nonzero = 0 not in axis_int_list
    return decision_from_bool(is_nonzero, 4)
def zero_axis_of_input_has_shape_of_1(args: ExprCheckInfo) -> Decision:
    """
    Accept if every input is either a scalar (empty shape) or has size 1 on axis 0.

    Rejects with code 5 otherwise.
    """
    batch_is_one = (len(shape) == 0 or shape[0] == 1 for shape in args.input_shapes)
    return decision_from_bool(all(batch_is_one), 5)
def dilations_of_pooling_expression_are_greater_than_1(args: ExprCheckInfo) -> Decision:
    """
    Accept if every dilation of the pooling expression is 1.

    Despite the function name, this accepts the supported case (all dilations equal
    to 1) and rejects with code 8 otherwise.
    """
    attrs = find_matching_attr(_is_pool_attribute, args.attrs)
    dilations = [int(d) for d in attrs.dilation]
    # TODO: Find a way to emit a RuntimeError because the operator tests expect one,
    # e.g. "We currently do not support dilation > 1 for avg_pool2d."
    # Compare each dilation with 1 directly; np.product no longer exists in NumPy 2.x.
    return decision_from_bool(all(d == 1 for d in dilations), 8)
def _is_integral_scale(scale) -> bool:
    """Return True if a resize scale factor is an int or a float with an integral value."""
    return isinstance(scale, int) or (isinstance(scale, float) and scale.is_integer())


def _is_positive_power_of_2(value: int) -> bool:
    """Return True if ``value`` is 1, 2, 4, 8, ... (exact integer test, no float log)."""
    return value > 0 and value & (value - 1) == 0


def check_resize_is_supported(args: ExprCheckInfo) -> Decision:
    """
    Check whether the scaling along the H and W axes is supported for the resize method.

    The following are supported:

    1. Only strictly positive output sizes / scale factors.
    2. Method 'linear' or 'bilinear' with 'half_pixel' coordinate transformation mode
       (or attributes without coordinate_transformation_mode): any scaling factor.
       For 8-bit integers, scaling by more than 63 may lead to loss in accuracy.
    3. Method 'nearest_neighbor': integral scaling factors along H and W that are
       powers of two, not both equal to 1.
    4. Any other method or scaling factor is not supported.

    :param args: ``args.attrs[0]`` holds the resize / upsampling attributes (with
        either ``size`` or ``scale_h``/``scale_w``) and ``args.input_shapes[0]`` the
        input shape, indexed via ``attrs.layout``.
    :return: Accept(), or Reject with code 11 (non-positive size/scale), 50 (unsupported
        nearest-neighbor scaling) or 12 (unsupported method).
    :raises NotImplementedError: If the attributes have neither ``size`` nor
        ``scale_h``/``scale_w``.
    """
    attrs = args.attrs[0]
    input_shape = args.input_shapes[0]
    layout: str = attrs.layout
    is_integral_resize: bool = True
    scaling_h: int = 0
    scaling_w: int = 0

    if hasattr(attrs, "size"):
        # Getting input and output image shapes
        in_h, in_w = input_shape[layout.index("H")], input_shape[layout.index("W")]
        out_h, out_w = [v.value for v in attrs.size]
        if out_h <= 0 or out_w <= 0:
            return Reject([11])
        if out_h >= in_h and out_h % in_h == 0 and out_w >= in_w and out_w % in_w == 0:
            scaling_h = out_h // in_h
            scaling_w = out_w // in_w
        else:
            is_integral_resize = False
    elif hasattr(attrs, "scale_h") and hasattr(attrs, "scale_w"):
        # A zero scale factor is not positive either (consistent with the size check above).
        if attrs.scale_h <= 0 or attrs.scale_w <= 0:
            return Reject([11])
        if _is_integral_scale(attrs.scale_h) and _is_integral_scale(attrs.scale_w):
            scaling_h = int(attrs.scale_h)
            scaling_w = int(attrs.scale_w)
        else:
            is_integral_resize = False
    else:
        raise NotImplementedError(f"Unsupported attrs encountered while checking MLA compatibility"
                                  f" {attrs.__class__.__name__}, Only upsampling and image.resize2d are supported")

    if (
        attrs.method in ('linear', 'bilinear')
        and (
            not hasattr(attrs, "coordinate_transformation_mode")
            or attrs.coordinate_transformation_mode == 'half_pixel'
        )
    ):
        return Accept()
    elif attrs.method == "nearest_neighbor":
        if not is_integral_resize:
            return Reject([50])
        # For nearest neighbor method of resizing we can scale by any power of 2,
        # but not by 1 along both axes.
        if (
            not _is_positive_power_of_2(scaling_h)
            or not _is_positive_power_of_2(scaling_w)
            or (scaling_h == 1 and scaling_w == 1)
        ):
            return Reject([50])
        return Accept()
    return Reject([12])
def pytorch_half_pixel_is_half_pixel(args: ExprCheckInfo) -> Decision:
    """
    Accept unless 'pytorch_half_pixel' mode is used with an output size of 1.

    'pytorch_half_pixel' behaves like 'half_pixel' when both resized dimensions
    (``attrs.size``) are greater than 1. Any other mode is accepted.
    A 'pytorch_half_pixel' resize with an output dimension of 1 is rejected with code 45.
    """
    attrs = args.attrs[0]
    if attrs.coordinate_transformation_mode != 'pytorch_half_pixel':
        return decision_from_bool(True, 45)
    out_h, out_w = (dim.value for dim in attrs.size)
    return decision_from_bool(out_h > 1 and out_w > 1, 45)
def corners_not_aligned(args: ExprCheckInfo) -> Decision:
    """
    Accept only when the resize does not align corners.

    Reads ``align_corners`` (upsampling) or ``coordinate_transformation_mode``
    (image.resize2d). Rejects with code 13 when corners are aligned.

    :raises NotImplementedError: If neither attribute is present.
    """
    attrs = args.attrs[0]
    if hasattr(attrs, "align_corners"):
        aligned = bool(attrs.align_corners)
    elif hasattr(attrs, "coordinate_transformation_mode"):
        aligned = attrs.coordinate_transformation_mode == "align_corners"
    else:
        raise NotImplementedError(f"Unsupported attrs encountered while checking MLA compatibility"
                                  f" {attrs.__class__.__name__}, Only upsampling and image.resize2d are supported")
    return decision_from_bool(not aligned, 13)
def pooling_expression_ceil_mode_is_false(args: ExprCheckInfo) -> Decision:
    """
    Accept if the pooling attributes have ceil_mode disabled; reject with code 14 otherwise.
    """
    # TODO: Find a way to emit a RuntimeError because the operator tests expect one,
    # telling the user that ceil_mode == True is unsupported for AvgPool/MaxPool and
    # that the SetCeilModeToFalseForNDPooling pass should be enabled on import.
    pool_attrs = find_matching_attr(_is_pool_attribute, args.attrs)
    ceil_mode_enabled = pool_attrs.ceil_mode
    return decision_from_bool(not ceil_mode_enabled, 14)
def _is_global_pool(input_shape: tuple[int, ...], pool_size: tuple[int, ...]) -> bool: return pool_size == input_shape[1:-1]
def supported_pool_size(args: ExprCheckInfo) -> Decision:
    """
    Accept a pool whose size is supported by the MLA, or which is a global pool.

    Rejects with code 9 otherwise.
    """
    pool_size = tuple(find_matching_attr(_is_pool_attribute, args.attrs).pool_size)
    input_shape = args.input_shapes[0]
    supported = is_supported_mla_pool_size(pool_size) or _is_global_pool(input_shape, pool_size)
    return decision_from_bool(supported, 9)
def pool_explicit_padding_incorrect_dimension(args: ExprCheckInfo) -> Decision:
    """
    Reject with code 16 unless the pooling padding has 2, 4 or 6 entries.
    """
    pool_attrs = find_matching_attr(_is_pool_attribute, args.attrs)
    if len(pool_attrs.padding) not in (2, 4, 6):
        # TODO: Find a way to emit a RuntimeError because the operator tests expect one,
        # with the message "Cannot calculate explicit 2D padding given padding of
        # (<len(padding)>) dimensions".
        return Reject([16])
    return Accept()
def _normalize_reduce_axis_size_4(axis: int) -> int: """ Convert an axis to a positive number. Negative axis is counted from the end. """ return 4 + axis if axis < 0 else axis
def axis_is_none_or_contains_0(args: ExprCheckInfo) -> Decision:
    """
    Accept if the first attribute's ``axis`` is None or contains 0; otherwise reject
    with code 10.

    NOTE(review): code 10's message describes squeezing along the batch axis as
    unsupported, which reads as the opposite of this acceptance condition; confirm
    how callers combine this check.
    """
    axis = args.attrs[0].axis
    accepted = True if axis is None else 0 in axis
    return decision_from_bool(accepted, 10)
def reduce_axis_is_3(args: ExprCheckInfo) -> Decision:
    """
    Accept if the reduce operator effectively reduces over axis 3 (the channel axis) only.

    The reduced axes are derived from ``axis`` (None = all axes, an int, or a list)
    and ``exclude`` (which inverts the selection). Axes whose input size is 1 are
    ignored, since reducing over them has no effect. Non-4D inputs are rejected.
    All rejections use code 40.
    """
    all_axes = {0, 1, 2, 3}
    attrs = args.attrs[0]
    assert len(args.input_shapes) == 1, "Reduce operator must have a single input"
    input_shape, = args.input_shapes
    if len(input_shape) != 4:
        return Reject([40])

    requested = attrs.axis
    if requested is None:
        reduced = set(all_axes)
    elif isinstance(requested, int):
        reduced = {_normalize_reduce_axis_size_4(requested)}
    else:
        reduced = {_normalize_reduce_axis_size_4(a) for a in requested}
    if attrs.exclude:
        reduced = all_axes - reduced

    unit_axes = {index for index, size in enumerate(input_shape) if size == 1}
    return decision_from_bool(reduced - unit_axes == {3}, 40)
def prelu_axis_is_3(args: ExprCheckInfo) -> Decision:
    """
    Accept if the first attribute's ``axis`` equals 3; reject with code 29 otherwise.

    Intended for the composite prelu_const operator: the PRelu alpha axis must be
    the channel axis, as required by the MLA.
    """
    alpha_axis = args.attrs[0].axis
    return decision_from_bool(alpha_axis == 3, 29)
def softmax_axis_is_3(args: ExprCheckInfo) -> Decision:
    """
    Accept if the first attribute's ``axis`` is 3 or -1; reject with code 41 otherwise.

    Intended for the composite softmax_const operator: the Softmax axis must be the
    channel axis, as required by the MLA.
    """
    softmax_axis = args.attrs[0].axis
    return decision_from_bool(softmax_axis in (3, -1), 41)
def output_size_shape_is_1(args: ExprCheckInfo) -> Decision:
    """
    Accept if ``output_size`` is set and the product of its entries is 1.

    A missing (None) ``output_size`` is rejected with code 18.
    """
    output_size = args.attrs[0].output_size
    if output_size is None:
        return decision_from_bool(False, 18)
    element_count = np.prod([int(dim) for dim in output_size])
    return decision_from_bool(element_count == 1, 18)
def pad_mode_is_constant_and_pad_width_is_0(args: ExprCheckInfo) -> Decision:
    """
    Accept if ``pad_mode`` is 'constant' and every entry of ``pad_width`` is 0.

    Rejects with code 19 otherwise.
    """
    attrs = args.attrs[0]
    # Every before/after pad amount on every dimension must be zero.
    zero_flags = [[amount == 0 for amount in per_dim] for per_dim in attrs.pad_width]
    all_zero = np.all(zero_flags)
    return decision_from_bool(attrs.pad_mode == 'constant' and all_zero, 19)
def check_transformation_mode_is_other_than_tf_crop_resize(args: ExprCheckInfo) -> Decision:
    """
    Reject (code 28) an image.resize whose coordinate_transformation_mode is
    'tf_crop_and_resize'; accept any other mode.
    """
    mode = args.attrs[0].coordinate_transformation_mode
    return decision_from_bool(mode != 'tf_crop_and_resize', 28)
def tuple_concat_axis_is_not_0(args: ExprCheckInfo) -> Decision:
    """
    Accept if tuple_concat's ``axis`` is not 0 (the batch dimension); reject with
    code 20 otherwise.

    NOTE: Only to be used by tuple_concat_checker.
    """
    # TODO: Find a way to emit a RuntimeError because the operator tests expect one,
    # e.g. NotImplementedError("Concatenate does not support axis along batch(0)").
    concat_axis = args.attrs[0].axis
    return decision_from_bool(concat_axis != 0, 20)
def axes_are_nonzero_or_zero_axis_of_input_has_shape_of_1(args: ExprCheckInfo) -> Decision:
    """
    Accept if the reduction avoids axis 0 (axes_of_reduce_expression_are_nonzero)
    or every input has size 1 on axis 0 (zero_axis_of_input_has_shape_of_1).
    Rejects with code 22 otherwise.

    The sub-checks return Decision objects, so their outcome is tested with
    isinstance(..., Accept). A Reject instance is not guaranteed to be falsy; if it
    is truthy, combining the raw results with ``or`` would always accept.
    """
    accepted = (isinstance(axes_of_reduce_expression_are_nonzero(args), Accept)
                or isinstance(zero_axis_of_input_has_shape_of_1(args), Accept))
    return decision_from_bool(accepted, 22)
def dilations_of_conv_transpose_expression_are_1(args: ExprCheckInfo) -> Decision:
    """
    Accept if every dilation of the 2d/3d transpose convolution is 1; reject with
    code 23 otherwise.

    The dilations are taken from the last attribute dictionary that has a
    ``dilation`` entry. If none has one, the check accepts.
    """
    dilations = []
    for attrs in args.attrs:
        if hasattr(attrs, 'dilation'):
            dilations = [int(d) for d in attrs.dilation]
    # TODO: Find a way to emit a RuntimeError because the operator tests expect one, e.g.
    # ValueError("Error: We do not support dilations in Conv2DTranspose that are greater than 1")
    # Compare each dilation with 1 directly; np.product no longer exists in NumPy 2.x.
    return decision_from_bool(all(d == 1 for d in dilations), 23)
def depthwise_or_number_of_groups_is_1(args: ExprCheckInfo) -> Decision:
    """
    Accept a transposed convolution that is ungrouped (number_of_groups_is_1) or
    depthwise (number_of_groups_equals_number_of_input_channels). Rejects with
    code 24 otherwise.

    The sub-checks return Decision objects, so their outcome is tested with
    isinstance(..., Accept). A Reject instance is not guaranteed to be falsy; if it
    is truthy, combining the raw results with ``or`` would always accept.
    """
    accepted = (isinstance(number_of_groups_is_1(args), Accept)
                or isinstance(number_of_groups_equals_number_of_input_channels(args), Accept))
    return decision_from_bool(accepted, 24)
def number_of_groups_is_1(args: ExprCheckInfo) -> Decision:
    """
    Accept if the ``groups`` attribute equals 1; reject with code 24 otherwise.

    The value comes from the last attribute dictionary that has ``groups``.
    If none has it, the group count is treated as 0, which is rejected.
    """
    group_values = [attr.groups for attr in args.attrs if hasattr(attr, 'groups')]
    groups = group_values[-1] if group_values else 0
    return decision_from_bool(groups == 1, 24)
def stride_is_power_of_2_upto_16(args: ExprCheckInfo) -> Decision:
    """
    Accept if every stride is one of 1, 2, 4, 8 or 16, the values the MLA hardware
    supports for transposed convolution; reject with code 35 otherwise.

    The strides come from the last attribute dictionary that has ``strides``.
    No strides at all is accepted.
    """
    allowed = (1, 2, 4, 8, 16)
    strides = ()
    for attr in args.attrs:
        if hasattr(attr, 'strides'):
            strides = attr.strides
    return Accept() if all(step in allowed for step in strides) else Reject([35])
def stride_is_1_or_2(args: ExprCheckInfo) -> Decision:
    """
    Accept if every stride is 1 or 2, the values the MLA hardware supports for
    depthwise transposed convolution; reject with code 44 otherwise.

    The strides come from the last attribute dictionary that has ``strides``.
    No strides at all is accepted.
    """
    strides = ()
    for attr in args.attrs:
        if hasattr(attr, 'strides'):
            strides = attr.strides
    return Accept() if all(step in (1, 2) for step in strides) else Reject([44])
def number_of_groups_equals_number_of_input_channels(args: ExprCheckInfo) -> Decision:
    """
    Special case for depthwise conv2d_transpose.

    Each attribute dictionary with a ``groups`` entry (one per conv2d_transpose in
    the decomposed pattern) increments the total group count. Accept if that count
    equals the number of input channels, i.e. the operator is depthwise. Reject with
    code 25 otherwise.

    The channel axis is located with the ``src_layout`` of the first layout_transform
    if present, else with the conv2d_transpose ``data_layout``.
    """
    group_count = 0
    conv_layout = None
    source_layout = None
    for attr in args.attrs:
        if hasattr(attr, "groups"):
            # Count one group per conv2d_transpose attribute dictionary.
            group_count += 1
        if conv_layout is None and hasattr(attr, "data_layout"):
            conv_layout = attr.data_layout
        if source_layout is None and hasattr(attr, "src_layout"):
            # The first layout_transform in the MergeComposite pattern is the one
            # transposing the input data layout.
            source_layout = attr.src_layout
    assert source_layout or conv_layout, "No convolution operators found"

    input_shape = args.input_shapes[0]
    layout = source_layout if source_layout else conv_layout
    channels = input_shape[layout.index("C")]
    # Support group conv2d_transpose if the group count equals the input channel count.
    return decision_from_bool(group_count == channels, 25)
def strided_slice_stride_is_1(args: ExprCheckInfo) -> Decision:
    """
    Accept if every strided-slice stride is exactly 1; reject with code 30 otherwise.

    Each stride is compared individually. The product of the strides is not a valid
    test, because it equals 1 for combinations such as (-1, -1), which reverse the
    tensor rather than slicing it with unit stride.
    """
    attrs = args.attrs[0]
    strides = [int(s) for s in attrs.strides]
    return decision_from_bool(all(s == 1 for s in strides), 30)
def input_is_4d(args: ExprCheckInfo) -> Decision:
    """
    Accept if every input shape has rank 4; reject with code 32 otherwise.
    """
    ranks = [len(shape) for shape in args.input_shapes]
    return decision_from_bool(all(rank == 4 for rank in ranks), 32)
def input_is_4d_or_5d(args: ExprCheckInfo) -> Decision:
    """
    Accept if every input shape has rank 4 or 5; reject with code 51 otherwise.
    """
    ranks = [len(shape) for shape in args.input_shapes]
    return decision_from_bool(all(rank in (4, 5) for rank in ranks), 51)
def _first_output_types_equal(args: ExprCheckInfo, dtype: str) -> bool:
    """
    Return True if, for every attribute dictionary with an ``output_types`` entry,
    the first element of that entry equals ``dtype``.

    Dictionaries without ``output_types`` are ignored; if none has it, the result is True.
    """
    return all(attr_dict['output_types'][0] == dtype
               for attr_dict in args.attrs if 'output_types' in attr_dict)


def all_tensors_are_float32(args: ExprCheckInfo) -> Decision:
    """
    Accept if the first ``output_types`` entry of every attribute dictionary that has
    one is float32; reject with code 36 otherwise.
    """
    return decision_from_bool(_first_output_types_equal(args, 'float32'), 36)


def all_tensors_are_int8(args: ExprCheckInfo) -> Decision:
    """
    Accept if the first ``output_types`` entry of every attribute dictionary that has
    one is int8; reject with code 37 otherwise.
    """
    return decision_from_bool(_first_output_types_equal(args, 'int8'), 37)


def all_tensors_are_int32(args: ExprCheckInfo) -> Decision:
    """
    Accept if the first ``output_types`` entry of every attribute dictionary that has
    one is int32; reject with code 39 otherwise.
    """
    return decision_from_bool(_first_output_types_equal(args, 'int32'), 39)
def strided_slice_on_chanel_axis_is_mul_of_16(args: ExprCheckInfo) -> Decision:
    """
    Accept if the slice bounds on the channel axis are aligned to multiples of 16.

    The layout is assumed to be NHWC, so the channel axis is 3.
    - ``begin`` on that axis must be a multiple of 16.
    - ``end`` must be a multiple of 16 unless it equals the full channel extent
      (``input_shapes[0][3]``).
    Slices that do not include the channel axis are accepted. Rejections use code 31.
    """
    # It is being assumed that data layout of slice op at this level is "NHWC".
    channel_axis = "NHWC".index("C")
    attrs = args.attrs[0]
    sliced_axes = attrs.axes
    if sliced_axes is None:
        sliced_axes = list(range(len(args.input_shapes[0])))

    # Position within begin/end of the (last) entry that refers to the channel axis.
    channel_pos = None
    for pos, ax in enumerate(sliced_axes):
        if ax == channel_axis:
            channel_pos = pos
    if channel_pos is None:
        return Accept()

    if attrs.begin[channel_pos] % 16 != 0:
        return Reject([31])
    end = attrs.end[channel_pos]
    if end != args.input_shapes[0][3] and end % 16 != 0:
        return Reject([31])
    return Accept()
def binary_operator_have_same_input_shapes_or_one_scalar_input(args: ExprCheckInfo) -> Decision:
    """
    Accept a binary operator whose two inputs have matching 4D/5D shapes, or whose
    single constant input matches the other input's shape or has one element.

    - Fewer or more than two input shapes: rejected.
    - Both inputs constant: rejected.
    - Exactly one input constant: the non-constant input must be 4D or 5D, and the
      constant must have the same shape or a shape whose element count is 1.
    - No constant inputs: both shapes must be equal and 4D or 5D.
    All rejections use code 33.
    """
    def _has_rank_4_or_5(shape: Tuple[int, ...]) -> bool:
        return len(shape) in (4, 5)

    shapes = args.input_shapes
    if len(shapes) != 2:
        # If it's a composite binary operator and one input was constant, the composite operator
        # has only one input. The constant's shape is not available here.
        return Reject([33])
    lhs_shape, rhs_shape = shapes
    const_flags = args.is_constant
    if all(const_flags):
        return Reject([33])
    if not any(const_flags):
        return decision_from_bool(_has_rank_4_or_5(lhs_shape) and lhs_shape == rhs_shape, 33)

    if const_flags[0]:
        variable_shape, constant_shape = rhs_shape, lhs_shape
    else:
        variable_shape, constant_shape = lhs_shape, rhs_shape
    return decision_from_bool(
        _has_rank_4_or_5(variable_shape)
        and (variable_shape == constant_shape or np.prod(constant_shape) == 1),
        33)
def _supported_broadcast(lhs_shape: Tuple[int, ...], rhs_shape: Tuple[int, ...], left_to_right: bool) -> bool: """ Returns True if two shapes are broadcastable. MLA follows the same broadcasting rules as NumPy, i.e. it requires all tensor dimension to be compatible. Two dimensions are compatible when they are equal, or one of them is 1. Broadcasting batch dimension is not supported. If left_to_right is True, broadcast is one way from left shape to right shape; otherwise, two ways. """ def _shape_is_same(axis: int): return lhs_shape[axis] == rhs_shape[axis] def _shape_is_1(axis: int): if left_to_right: return lhs_shape[axis] == 1 else: return lhs_shape[axis] == 1 or rhs_shape[axis] == 1 for axis in range(1, len(lhs_shape)): if not (_shape_is_same(axis) or _shape_is_1(axis)): return False return True
def binary_operator_inputs_are_broadcastable(args: ExprCheckInfo) -> Decision:
    """
    Accept if the two inputs have the same rank and are mutually broadcastable.

    Rejects with code 33 unless there are exactly two input shapes, and with code 34
    when the ranks differ or the shapes are not broadcastable.
    """
    shapes = args.input_shapes
    if len(shapes) != 2:
        # If it's a composite binary operator and one input was constant, the composite operator
        # has only one input. The constant's shape is not available here.
        return Reject([33])
    lhs_shape, rhs_shape = shapes
    # Broadcasting will not extend number of dimensions
    same_rank = len(lhs_shape) == len(rhs_shape)
    if same_rank and _supported_broadcast(lhs_shape, rhs_shape, left_to_right=False):
        return Accept()
    return Reject([34])
def check_number_of_einsum_inputs(num_inputs: int) -> Callable[[ExprCheckInfo], Decision]:
    """
    Build a checker that accepts when exactly ``num_inputs`` input shapes are present.

    :param num_inputs: Expected number of einsum inputs.
    :return: A check function rejecting with code 42 on a count mismatch.
    """
    def _check_einsum_input_count(args: ExprCheckInfo) -> Decision:
        """
        Accept if number of inputs wrapped in a tuple match num_inputs argument.
        """
        input_count = len(args.input_shapes)
        return decision_from_bool(input_count == num_inputs, 42)

    return _check_einsum_input_count
def einsum_equation_is_supported(args: ExprCheckInfo) -> Decision:
    """
    Accept if the attributes carry an ``equation`` that is_mla_supported_einsum_equation
    accepts for the NHWC layout; reject with code 43 otherwise.
    """
    attrs = args.attrs[0]
    if 'equation' not in attrs:
        return decision_from_bool(False, 43)
    supported = is_mla_supported_einsum_equation(attrs.equation, data_layout="NHWC")
    return decision_from_bool(supported, 43)
def check_strides_in_conv(args: ExprCheckInfo) -> Decision:
    """
    Accept if every stride lies in the inclusive range [1, 31]; reject with code 46 otherwise.
    """
    strides = args.attrs[0]['strides']
    if any(not 1 <= step <= 31 for step in strides):
        return Reject([46])
    return Accept()
def broadcast_to_output_is_4d_and_broadcastable(args: ExprCheckInfo):
    """
    Accept if the target ``shape`` is 4D and the input can be broadcast to it one-way
    (input -> output); reject with code 48 otherwise.
    """
    output_shape = args.attrs[0]['shape']
    input_shape = args.input_shapes[0]
    ok = len(output_shape) == 4 and _supported_broadcast(input_shape, output_shape, left_to_right=True)
    return Accept() if ok else Reject([48])
def check_transpose_affecting_batch_axis(args: ExprCheckInfo):
    """
    Accept if the transpose keeps the batch axis (axis 0) in place; reject with code 49
    otherwise.

    ``args.attrs[0]['axes']`` is the permutation. A negative first entry is normalized
    against the permutation length, which equals the input rank. ``axes=None`` is
    treated as reversing all dimensions (Relay's default), which moves the batch axis
    whenever the input (``args.input_shapes[0]``) has more than one dimension.
    """
    axes = args.attrs[0]['axes']
    if axes is None:
        rank = len(args.input_shapes[0])
        return Reject([49]) if rank > 1 else Accept()
    first_axis = int(axes[0])
    if first_axis < 0:
        # Negative axes count from the end; the permutation has one entry per dimension.
        first_axis += len(axes)
    return Reject([49]) if first_axis != 0 else Accept()
def supported_variance_axis(args: ExprCheckInfo):
    """
    Accept if the variance is computed over exactly the spatial axes of the input;
    reject with code 52 otherwise.

    The spatial axes are every axis except the first and the last
    (``1 .. rank - 2``). Negative axes are normalized against the input rank, and the
    order in which axes are listed does not matter.
    ``axis=None`` (all axes) is rejected. A single int axis is accepted as a one-element list.
    """
    raw_axis = args.attrs[0]['axis']
    input_shape = args.input_shapes[0]
    rank = len(input_shape)
    if raw_axis is None:
        return Reject([52])
    axes = [raw_axis] if isinstance(raw_axis, int) else list(raw_axis)
    normalized = tuple(sorted(int(a) + rank if int(a) < 0 else int(a) for a in axes))
    supported_axes = tuple(range(1, rank - 1))
    return Accept() if normalized == supported_axes else Reject([52])
def supported_gridsample(args: ExprCheckInfo) -> Decision:
    """
    Accept a grid-sample operator if its interpolation is 2D (4D input), its method
    is 'bilinear' and its padding_mode is not 'reflection'. Reject with code 53 otherwise.

    The grid-sample attributes are the first attribute dictionary that has a
    'method' key.

    :raises ValueError: If no attribute dictionary has a 'method' key.
    """
    input_shape = args.input_shapes[0]
    gridsample_attrs = find_matching_attr(lambda a: 'method' in a, args.attrs,
                                          description="a 'method' key (grid sample attributes)")
    mode = gridsample_attrs['method']
    padding_mode = gridsample_attrs['padding_mode']
    if len(input_shape) != 4 or mode != 'bilinear' or padding_mode == 'reflection':
        return Reject([53])
    return Accept()