sima_utils.transformer.onnx_builder

Attributes

OnnxNode

Classes

OnnxBuilder

Helper class to build an ONNX model.

Module Contents

sima_utils.transformer.onnx_builder.OnnxNode
class sima_utils.transformer.onnx_builder.OnnxBuilder

Helper class to build an ONNX model.

IR_VERSION

IR version of the onnx model.

OPSET_ID

Operator set id of the onnx model.

onnx_file_name

File name of the onnx file.

get_param_func

A function that returns a parameter tensor with the provided parameter name.

check_param_func

A function that checks if a parameter tensor with the provided parameter name exists.

input_nodes

Input nodes of the onnx model.

output_nodes

Output nodes of the onnx model.

_initializer_map

A mapping from a name to an onnx initializer.

_node_map

A mapping from a name to an onnx node.

IR_VERSION: ClassVar[int] = 8
OPSET_ID: ClassVar[int] = 17
onnx_file_name: pathlib.Path
get_param_func: collections.abc.Callable[[str], numpy.ndarray] | None = None
check_param_func: collections.abc.Callable[[str], bool] | None = None
input_nodes: list[onnx.ValueInfoProto] = []
output_nodes: list[onnx.ValueInfoProto] = []
create_and_save_model(do_simplify: bool = True)

Creates and saves the model.

Parameters:

do_simplify – Set true to simplify the created onnx graph.

create_model() onnx.ModelProto

Creates the model and performs shape inference.

save_model(model: onnx.ModelProto)

Saves the model to a file.

Parameters:

model – The onnx model to be saved.

create_input_node(name: str, shape: collections.abc.Sequence[int], dtype: type = np.float32)

Creates an input node with the provided name, shape and data type.

create_output_node(name: str, shape: collections.abc.Sequence[int], dtype: type = np.float32)

Creates an output node with the provided name, shape and data type.

get_node_output_names(node: OnnxNode) list[str]

Gets a list of the output names of the given node.

get_node_output_name(node: OnnxNode) str

Gets the output name of the given node with only one output.

create_initializer(name: str, value: int | float | numpy.ndarray | None = None, reshape_str: str | None = None) OnnxNode | None

Creates an initializer with the name.

Parameters:
  • name – Initializer name.

  • value – Value of the initializer. If value is None, then look up the value using get_param_func; if get_param_func is None, then look up the value using the pre-defined file name.

  • reshape_str – A string to reshape the value.

Returns:

Return an initializer if a valid value is found. Otherwise, return None.

reshape_data(data: numpy.ndarray, reshape_str: str | None = None) numpy.ndarray
build_op(base_name: str, input_nodes: collections.abc.Sequence[OnnxNode], op_type: str, **kwargs) OnnxNode

Builds an ONNX node.

Parameters:
  • base_name – Base name of the operator. This is used to create the node and the initializer.

  • input_nodes – A list of input nodes.

  • op_type – Name of the operator type.

  • **kwargs – Operator attributes.

Returns:

Created ONNX node.

build_conv(base_name: str, input_node: OnnxNode, is_fc: bool = True, **kwargs) OnnxNode

Builds a convolution node.

Parameters:
  • base_name – Base name of the operator. This is used to create the node and the initializers.

  • input_node – The input node of the convolution node.

  • is_fc – Set True to indicate that the original operator is a fully-connected layer or a matrix multiplication where the weight needs to be reshaped to build the convolution.

  • **kwargs – Convolution attributes.

Returns:

Created convolution node.
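
The is_fc description can be illustrated with a small NumPy sketch showing why a fully-connected weight can be reshaped into a convolution kernel. This is not the OnnxBuilder implementation; the function name fc_as_pointwise_conv, the NCHW activation layout, and the (C_out, C_in) weight layout are assumptions made for illustration.

```python
import numpy as np


def fc_as_pointwise_conv(x, weight):
    """Apply a fully-connected weight as a 1x1 convolution (illustrative only).

    x:      activations, assumed NCHW layout, shape (N, C_in, H, W)
    weight: fully-connected weight, assumed shape (C_out, C_in)
    """
    # Reshape (C_out, C_in) -> (C_out, C_in, 1, 1), the shape of a 1x1 kernel.
    kernel = weight.reshape(weight.shape[0], weight.shape[1], 1, 1)
    # A 1x1 convolution is a per-position matrix multiply over the channel axis.
    return np.einsum("nchw,ocij->nohw", x, kernel)


rng = np.random.default_rng(0)
x = rng.standard_normal((2, 8, 3, 5))
w = rng.standard_normal((4, 8))
out = fc_as_pointwise_conv(x, w)

# Reference: plain matrix multiplication applied at every spatial position.
ref = np.transpose(np.transpose(x, (0, 2, 3, 1)) @ w.T, (0, 3, 1, 2))
```

Both paths produce the same values, which is why a matrix multiplication can be expressed as a convolution once the weight is reshaped.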

build_split_and_concat(base_name: str, input_node: OnnxNode, num_splits: int, split_axis: int, concat_axis: int) OnnxNode

Builds nodes for a split-and-concat operation.

Parameters:
  • base_name – Base name of the operator. This is used to create the nodes.

  • input_node – The input node of the split node.

  • num_splits – Number of splits.

  • split_axis – The axis for the input node to be split.

  • concat_axis – The axis for the split nodes to be concatenated.

Returns:

Created split and concatenate nodes.
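
As a rough illustration of what a split-and-concat operation does to tensor data, here is a NumPy sketch. It only mirrors the parameter names above; the helper split_and_concat is hypothetical and assumes equal-sized splits, which the original text does not state.

```python
import numpy as np


def split_and_concat(x, num_splits, split_axis, concat_axis):
    """Split x into equal chunks on one axis and join them on another."""
    # np.split requires the axis length to be divisible by num_splits.
    chunks = np.split(x, num_splits, axis=split_axis)
    return np.concatenate(chunks, axis=concat_axis)


x = np.arange(24).reshape(1, 4, 6)
y = split_and_concat(x, num_splits=2, split_axis=2, concat_axis=1)
print(y.shape)  # (1, 8, 3)
```

The two halves of the last axis end up stacked along axis 1, so the total number of elements is unchanged.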

build_split_expand_concat(base_name: str, input_node: OnnxNode, num_splits: int, num_repeats: int, split_axis: int, concat_axis: int, concat_shape: tuple[int, Ellipsis]) OnnxNode

Builds nodes for a split-expand-concat operation.

This is to support Grouped-Query Attention (GQA), where the number of KV heads is less than the number of attention heads. The KV tensors read from a KV cache need to be repeated to match the number of attention heads.

Parameters:
  • base_name – Base name of the operator. This is used to create the nodes.

  • input_node – The input node of the split node.

  • num_splits – Number of splits.

  • num_repeats – Number of repeats for each split.

  • split_axis – The axis for the input node to be split.

  • concat_axis – The axis for the split nodes to be concatenated.

Returns:

Created split, expand, and concatenate nodes.
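
The KV-head repetition described above can be sketched in NumPy. This is an illustration under assumptions, not the builder's graph: the helper split_expand_concat is hypothetical, each split is assumed to be repeated consecutively (head 0 copies first, then head 1, and so on), and the concat_shape parameter is left out because its meaning is not documented here.

```python
import numpy as np


def split_expand_concat(x, num_splits, num_repeats, split_axis, concat_axis):
    """Split x, repeat every chunk num_repeats times, then concatenate."""
    chunks = np.split(x, num_splits, axis=split_axis)
    # Repeat each chunk back to back so copies of one KV head stay adjacent.
    expanded = [chunk for chunk in chunks for _ in range(num_repeats)]
    return np.concatenate(expanded, axis=concat_axis)


# 2 KV heads expanded to serve 6 attention heads (3 repeats per KV head).
kv = np.arange(2 * 3 * 4).reshape(2, 3, 4)  # (kv_heads, seq_len, head_dim)
out = split_expand_concat(kv, num_splits=2, num_repeats=3,
                          split_axis=0, concat_axis=0)
print(out.shape)  # (6, 3, 4)
```

When split_axis and concat_axis are the same, this matches np.repeat(kv, num_repeats, axis=split_axis).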

build_layer_norm(base_name: str, input_node: OnnxNode, epsilon: float = 1e-05) OnnxNode

Builds nodes for layer norm operation.

Parameters:
  • base_name – Base name of the operator. This is used to create the nodes.

  • input_node – The input node of the layer norm operation.

  • epsilon – Epsilon of the layer normalization.

Returns:

Created layer norm nodes.
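
For reference, the standard layer normalization computation can be written in NumPy as below. This is a sketch of the math only; normalizing over the last axis and the names gamma and beta are assumptions, since the original does not say how the builder decomposes the operation.

```python
import numpy as np


def layer_norm(x, gamma, beta, epsilon=1e-5):
    """Standard layer normalization over the last axis (illustrative only)."""
    mean = x.mean(axis=-1, keepdims=True)
    var = ((x - mean) ** 2).mean(axis=-1, keepdims=True)
    # epsilon keeps the division stable when the variance is close to zero.
    return (x - mean) / np.sqrt(var + epsilon) * gamma + beta


rng = np.random.default_rng(0)
x = rng.standard_normal((4, 16))
y = layer_norm(x, gamma=np.ones(16), beta=np.zeros(16))
```

With unit gamma and zero beta, every row of y has approximately zero mean and unit variance.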

build_rms_norm(base_name: str, input_node: OnnxNode, epsilon: float, weight_offset: float) OnnxNode

Builds nodes for RMS norm operation.

Parameters:
  • base_name – Base name of the operator. This is used to create the nodes.

  • input_node – The input node of the RMS norm operation.

  • epsilon – Epsilon of the RMS normalization.

  • weight_offset – Offset to the weights.

Returns:

Created RMS norm nodes.
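
A minimal NumPy sketch of RMS normalization follows. It assumes that weight_offset is added to the learned weight before scaling (as in models that store the scale as an offset from 1) and that normalization runs over the last axis; neither detail is confirmed by the text above, and the helper rms_norm is hypothetical.

```python
import numpy as np


def rms_norm(x, weight, epsilon, weight_offset):
    """RMS normalization over the last axis (illustrative only)."""
    # Root mean square of the features; epsilon avoids division by zero.
    rms = np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + epsilon)
    # Assumption: the offset is applied to the weight, not to the output.
    return x / rms * (weight + weight_offset)


rng = np.random.default_rng(0)
x = rng.standard_normal((4, 16))
y_offset = rms_norm(x, np.zeros(16), epsilon=1e-6, weight_offset=1.0)
y_plain = rms_norm(x, np.ones(16), epsilon=1e-6, weight_offset=0.0)
```

Under this assumption, a zero weight with an offset of 1.0 behaves the same as a unit weight with no offset.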

build_logit_softcapping(base_name: str, input_node: OnnxNode, scalar: float) OnnxNode

Builds nodes for logit soft capping.

Logit soft capping is used in GEMMA2 to prevent overconfident predictions.

softcapping(x) = scalar * tanh(x / scalar)
               = scalar * [2 * sigmoid(2x / scalar) - 1]
               = (2 * scalar) * sigmoid(x * (2 / scalar)) - scalar

Operations: x -> mul -> sigmoid -> mul -> sub
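
The identity can be checked numerically with a short Python sketch. The function names are illustrative and not part of OnnxBuilder; the sigmoid variant follows the mul, sigmoid, mul, sub sequence listed above.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def softcap_tanh(x, scalar):
    """Reference definition: scalar * tanh(x / scalar)."""
    return scalar * math.tanh(x / scalar)


def softcap_sigmoid(x, scalar):
    """Same function built from mul -> sigmoid -> mul -> sub."""
    y = x * (2.0 / scalar)   # mul
    y = sigmoid(y)           # sigmoid
    y = y * (2.0 * scalar)   # mul
    return y - scalar        # sub


for x in (-100.0, -5.0, 0.0, 0.3, 7.0, 100.0):
    print(x, softcap_tanh(x, 30.0), softcap_sigmoid(x, 30.0))
```

Both forms agree to floating-point precision, and the output always stays within (-scalar, scalar).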

build_activation(base_name: str, input_node: OnnxNode, act_type: str) OnnxNode

Builds nodes for activation.

LLAMA uses “silu”, which is built on sigmoid. GEMMA uses “gelu_pytorch_tanh”, which is the Gaussian Error Linear Unit (GELU) with tanh approximation.

Because tanh(x) = 2*sigmoid(2x)-1, the tanh-approximated gelu can also be computed with sigmoid:

gelu_tanh(x) = 0.5 * x * [1 + tanh(root(2/PI) * x * (1 + 0.044715 * x * x))]
             = x * sigmoid(x * (A + B * x * x))

where A = 2 * root(2/PI), B = A * 0.044715
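
The rewrite of gelu_tanh in terms of sigmoid can be verified with a short Python sketch. The function names are illustrative only; A and B are the constants defined above, and silu is included for comparison since it uses the same sigmoid building block.

```python
import math

A = 2.0 * math.sqrt(2.0 / math.pi)
B = A * 0.044715


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def silu(x):
    """silu(x) = x * sigmoid(x)."""
    return x * sigmoid(x)


def gelu_tanh(x):
    """Reference tanh approximation of gelu."""
    inner = math.sqrt(2.0 / math.pi) * x * (1.0 + 0.044715 * x * x)
    return 0.5 * x * (1.0 + math.tanh(inner))


def gelu_sigmoid(x):
    """Same function expressed with a single sigmoid."""
    return x * sigmoid(x * (A + B * x * x))


for x in (-5.0, -1.0, 0.0, 0.5, 2.0, 5.0):
    print(x, gelu_tanh(x), gelu_sigmoid(x))
```

The two gelu forms match to floating-point precision, so only a sigmoid operator is needed for either activation type.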

build_matmul_and_split_heads(base_name: str, input_node: OnnxNode, num_heads: int, seq_len: int, kv_len: int | None = None, post_matmul_scale: float = 1.0) list[OnnxNode]
build_merge_heads_and_matmul(base_name: str, input_nodes: list[OnnxNode], num_heads: int) OnnxNode
build_attention(base_name: str, input_nodes: list[OnnxNode], num_heads: int, head_dim: int, seq_len: int, kv_len: int | None = None, skip_kv_projs_and_split_head: bool = False, mask_node: OnnxNode | None = None, output_kv_projs: bool = False) list[OnnxNode]
build_encoder_decoder_mlp(base_name: str, input_node: OnnxNode, act_type: str) OnnxNode