sima_utils.transformer.onnx_builder
Attributes
- OnnxNode
Classes
- OnnxBuilder – Helper class to build an ONNX model.
Module Contents
- sima_utils.transformer.onnx_builder.OnnxNode
- class sima_utils.transformer.onnx_builder.OnnxBuilder
Helper class to build an ONNX model.
- IR_VERSION
IR version of the ONNX model.
- OPSET_ID
Operator set ID of the ONNX model.
- onnx_file_name
File name of the ONNX file.
- get_param_func
A function that returns a parameter tensor with the provided parameter name.
- check_param_func
A function that checks if a parameter tensor with the provided parameter name exists.
- input_nodes
Input nodes of the ONNX model.
- output_nodes
Output nodes of the ONNX model.
- _initializer_map
A mapping from a name to an ONNX initializer.
- _node_map
A mapping from a name to an ONNX node.
- IR_VERSION: ClassVar[int] = 8
- OPSET_ID: ClassVar[int] = 17
- onnx_file_name: pathlib.Path
- get_param_func: collections.abc.Callable[[str], numpy.ndarray] | None = None
- check_param_func: collections.abc.Callable[[str], bool] | None = None
- input_nodes: list[onnx.ValueInfoProto] = []
- output_nodes: list[onnx.ValueInfoProto] = []
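A minimal usage sketch of how these fields might be wired up, assuming the builder can be constructed directly from them (for example, dataclass-style). The constructor signature and the parameter names used below are assumptions, not part of this reference.

```python
from pathlib import Path

import numpy as np

from sima_utils.transformer.onnx_builder import OnnxBuilder

# Hypothetical parameter store; the names are made up for illustration.
params = {
    "layer.0.weight": np.random.rand(64, 64).astype(np.float32),
    "layer.0.bias": np.zeros(64, dtype=np.float32),
}

# Assumes OnnxBuilder can be built from the fields documented above;
# adjust to the actual constructor if it differs.
builder = OnnxBuilder(
    onnx_file_name=Path("model.onnx"),
    get_param_func=lambda name: params[name],      # returns a parameter tensor by name
    check_param_func=lambda name: name in params,  # reports whether the name exists
)
```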
- create_and_save_model(do_simplify: bool = True)
Creates and saves the model.
- Parameters:
do_simplify – Set True to simplify the created ONNX graph.
- create_model() → onnx.ModelProto
Creates the model and performs shape inference.
- save_model(model: onnx.ModelProto)
Saves the model to a file.
- Parameters:
model – The ONNX model to be saved.
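For orientation, the plain onnx calls that roughly correspond to this create/save path look like the sketch below (IR version 8, opset 17, shape inference, then a write to disk). This is only an illustration of the surrounding onnx API, not the builder's implementation, and the do_simplify step is omitted.

```python
import onnx
from onnx import helper, shape_inference

# A trivial one-node graph, just to have something to infer shapes on.
node = helper.make_node("Identity", inputs=["x"], outputs=["y"])
x = helper.make_tensor_value_info("x", onnx.TensorProto.FLOAT, [1, 4])
y = helper.make_tensor_value_info("y", onnx.TensorProto.FLOAT, [1, 4])

graph = helper.make_graph([node], "demo", inputs=[x], outputs=[y])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 17)])
model.ir_version = 8                          # matches IR_VERSION above
model = shape_inference.infer_shapes(model)   # create_model() also performs shape inference
onnx.save(model, "model.onnx")                # save_model() writes the model to a file
```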
- create_input_node(name: str, shape: collections.abc.Sequence[int], dtype: type = np.float32)
Creates an input node with the provided name, shape and data type.
- create_output_node(name: str, shape: collections.abc.Sequence[int], dtype: type = np.float32)
Creates an output node with the provided name, shape and data type.
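Since input_nodes and output_nodes are lists of onnx.ValueInfoProto, the entries these two methods presumably append carry a name, an element type, and a shape, as in this sketch (example names and shapes only):

```python
from onnx import TensorProto, helper

# ValueInfoProto entries like the ones collected in input_nodes / output_nodes.
inp = helper.make_tensor_value_info("hidden_in", TensorProto.FLOAT, [1, 128, 768])
out = helper.make_tensor_value_info("hidden_out", TensorProto.FLOAT, [1, 128, 768])
```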
- get_node_output_names(node: OnnxNode) → list[str]
Gets a list of the output names of the given node.
- get_node_output_name(node: OnnxNode) → str
Gets the output name of the given node with only one output.
- create_initializer(name: str, value: int | float | numpy.ndarray | None = None, reshape_str: str | None = None) → OnnxNode | None
Creates an initializer with the name.
- Parameters:
name – Initializer name.
value – Value of the initializer. If value is None, the value is looked up using get_param_func; if get_param_func is also None, the value is looked up using the pre-defined file name.
reshape_str – A string describing how to reshape the value.
- Returns:
Return an initializer if a valid value is found. Otherwise, return None.
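A hypothetical sketch of that fallback order, together with how a resolved array becomes an ONNX initializer. resolve_value is illustrative, not the builder's code, and the file-based fallback is only noted in a comment.

```python
import numpy as np
from onnx import numpy_helper

def resolve_value(name, value, get_param_func):
    """Hypothetical resolution order for create_initializer's value argument."""
    if value is not None:
        return np.asarray(value)
    if get_param_func is not None:
        return get_param_func(name)  # look the parameter tensor up by name
    return None                      # the real code falls back to a pre-defined file name

weight = resolve_value("proj.weight", None, lambda n: np.ones((4, 4), np.float32))
init = numpy_helper.from_array(weight, name="proj.weight")  # a named constant tensor
```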
- reshape_data(data: numpy.ndarray, reshape_str: str | None = None) → numpy.ndarray
Reshapes the data according to the optional reshape string.
- build_op(base_name: str, input_nodes: collections.abc.Sequence[OnnxNode], op_type: str, **kwargs) → OnnxNode
Builds an ONNX node.
- Parameters:
base_name – Base name of the operator. This is used to create the node and the initializer.
input_nodes – A list of input nodes.
op_type – Name of the operator type.
**kwargs – Operator attributes.
- Returns:
Created ONNX node.
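Conceptually this mirrors onnx.helper.make_node, where operator attributes are also passed as keyword arguments. The snippet below is a plain-onnx illustration with made-up tensor and node names, not a call into the builder.

```python
from onnx import helper

# A generic node with an attribute kwarg, analogous to build_op's **kwargs.
node = helper.make_node(
    "Softmax",
    inputs=["attn_scores"],
    outputs=["attn_probs"],
    name="attn/softmax",
    axis=-1,  # operator attribute
)
```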
- build_conv(base_name: str, input_node: OnnxNode, is_fc: bool = True, **kwargs) → OnnxNode
Builds a convolution node.
- Parameters:
base_name – Base name of the operator. This is used to create the node and the initializers.
input_node – The input node of the convolution node.
is_fc – Set True to indicate that the original operator is a fully-connected layer or a matrix multiplication where the weight needs to be reshaped to build the convolution.
**kwargs – Convolution attributes.
- Returns:
Created convolution node.
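The is_fc reshape can be pictured in numpy: a fully-connected weight of shape (out, in) becomes a 1x1 convolution weight of shape (out, in, 1, 1) in the ONNX Conv layout, and the two compute the same result. This is a sketch of the idea under that layout assumption, not the builder's code.

```python
import numpy as np

out_ch, in_ch = 8, 16
w_fc = np.random.rand(out_ch, in_ch).astype(np.float32)
x = np.random.rand(in_ch).astype(np.float32)

y_fc = w_fc @ x                             # fully-connected result
w_conv = w_fc.reshape(out_ch, in_ch, 1, 1)  # Conv weight layout (O, I, kH, kW)
x_nchw = x.reshape(1, in_ch, 1, 1)          # treat the vector as a 1x1 feature map
y_conv = np.einsum("oihw,nihw->no", w_conv, x_nchw)[0]

assert np.allclose(y_fc, y_conv, atol=1e-5)
```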
- build_split_and_concat(base_name: str, input_node: OnnxNode, num_splits: int, split_axis: int, concat_axis: int) → OnnxNode
Builds nodes for a split-and-concat operation.
- Parameters:
base_name – Base name of the operator. This is used to create the nodes.
input_node – The input node of the split node.
num_splits – Number of splits.
split_axis – The axis along which the input node is split.
concat_axis – The axis along which the split nodes are concatenated.
- Returns:
Created split and concatenate nodes.
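The effect can be sketched in numpy: split along split_axis, then re-join the pieces along concat_axis (example shapes only).

```python
import numpy as np

x = np.arange(2 * 4 * 6).reshape(2, 4, 6)
pieces = np.split(x, 2, axis=1)     # num_splits=2, split_axis=1 -> two (2, 2, 6) pieces
y = np.concatenate(pieces, axis=2)  # concat_axis=2 -> shape (2, 2, 12)
assert y.shape == (2, 2, 12)
```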
- build_split_expand_concat(base_name: str, input_node: OnnxNode, num_splits: int, num_repeats: int, split_axis: int, concat_axis: int, concat_shape: tuple[int, ...]) → OnnxNode
Builds nodes for a split-expand-concat operation.
This is to support Grouped Query Attention (GQA), where the number of KV heads is less than the number of attention heads. The KV tensors coming out of a KV cache need to be repeated to match the number of attention heads.
- Parameters:
base_name – Base name of the operator. This is used to create the nodes.
input_node – The input node of the split node.
num_splits – Number of splits.
num_repeats – Number of repeats for each split.
split_axis – The axis along which the input node is split.
concat_axis – The axis along which the split nodes are concatenated.
- Returns:
Created split, expand, and concatenate nodes.
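A numpy sketch of the repetition this builds for GQA, assuming a (heads, seq_len, head_dim) layout and example sizes; the real axis layout depends on how the surrounding graph is constructed.

```python
import numpy as np

num_heads, num_kv_heads = 8, 2
seq_len, head_dim = 16, 64
kv = np.random.rand(num_kv_heads, seq_len, head_dim).astype(np.float32)

splits = np.split(kv, num_kv_heads, axis=0)                                   # split_axis = 0
expanded = [np.repeat(s, num_heads // num_kv_heads, axis=0) for s in splits]  # num_repeats = 4
out = np.concatenate(expanded, axis=0)                                        # concat_axis = 0

assert out.shape == (num_heads, seq_len, head_dim)
```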
- build_layer_norm(base_name: str, input_node: OnnxNode, epsilon: float = 1e-05) → OnnxNode
Builds nodes for a layer norm operation.
- Parameters:
base_name – Base name of the operator. This is used to create the nodes.
input_node – The input node of the layer norm operation.
epsilon – Epsilon of the layer normalization.
- Returns:
Created layer norm nodes.
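For reference, a plain layer normalization over the last axis with the default epsilon above; how the builder handles scale and bias parameters is not shown in this reference.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

x = np.random.rand(2, 8).astype(np.float32)
y = layer_norm(x)
assert np.allclose(y.mean(axis=-1), 0.0, atol=1e-5)
```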
- build_rms_norm(base_name: str, input_node: OnnxNode, epsilon: float, weight_offset: float) → OnnxNode
Builds nodes for an RMS norm operation.
- Parameters:
base_name – Base name of the operator. This is used to create the nodes.
input_node – The input node of the RMS norm operation.
epsilon – Epsilon of the RMS normalization.
weight_offset – Offset added to the weights.
- Returns:
Created RMS norm nodes.
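A reference RMS normalization matching the parameters above. Treating weight_offset as an additive offset to the weights (as in models such as GEMMA, which scale by 1 + weight) is an assumption based on the description.

```python
import numpy as np

def rms_norm(x, weight, eps, weight_offset=0.0):
    rms = np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)
    return (x / rms) * (weight + weight_offset)  # assumed weight_offset semantics

x = np.random.rand(2, 8).astype(np.float32)
w = np.ones(8, dtype=np.float32)
y = rms_norm(x, w, eps=1e-6, weight_offset=0.0)
```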
- build_logit_softcapping(base_name: str, input_node: OnnxNode, scalar: float) → OnnxNode
Builds nodes for logit soft capping.
Logit soft capping is used in GEMMA2 to prevent overconfident predictions.
- softcapping(x) = scalar * tanh(x / scalar)
                 = scalar * [2 * sigmoid(2x / scalar) - 1]
                 = (2 * scalar) * sigmoid(x * (2 / scalar)) - scalar
- Operations: x → Mul → Sigmoid → Mul → Sub
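A quick numerical check that the sigmoid form is identical to scalar * tanh(x / scalar):

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

scalar = 30.0  # arbitrary example cap value
x = np.linspace(-100, 100, 1001)
ref = scalar * np.tanh(x / scalar)
alt = (2 * scalar) * sigmoid(x * (2 / scalar)) - scalar
assert np.allclose(ref, alt, atol=1e-9)
```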
- build_activation(base_name: str, input_node: OnnxNode, act_type: str) → OnnxNode
Builds nodes for an activation.
LLAMA uses "silu", which is based on the sigmoid. GEMMA uses "gelu_pytorch_tanh", which approximates the Gaussian error function (GELU) with tanh.
- Because tanh(x) = 2 * sigmoid(2x) - 1, GELU can also be approximated using the sigmoid:
- gelu_tanh(x) = 0.5 * x * [1 + tanh(sqrt(2/pi) * x * (1 + 0.044715 * x * x))]
               = x * sigmoid(x * (A + B * x * x))
  where A = 2 * sqrt(2/pi) and B = A * 0.044715.
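And a numerical check that the sigmoid form equals the tanh-approximated GELU:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

x = np.linspace(-6, 6, 1001)
gelu_tanh = 0.5 * x * (1 + np.tanh(np.sqrt(2 / np.pi) * x * (1 + 0.044715 * x * x)))

A = 2 * np.sqrt(2 / np.pi)
B = A * 0.044715
gelu_sig = x * sigmoid(x * (A + B * x * x))

assert np.allclose(gelu_tanh, gelu_sig, atol=1e-9)
```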
- build_matmul_and_split_heads(base_name: str, input_node: OnnxNode, num_heads: int, seq_len: int, kv_len: int | None = None, post_matmul_scale: float = 1.0) → list[OnnxNode]
- build_merge_heads_and_matmul(base_name: str, input_nodes: list[OnnxNode], num_heads: int) → OnnxNode
- build_attention(base_name: str, input_nodes: list[OnnxNode], num_heads: int, head_dim: int, seq_len: int, kv_len: int | None = None, skip_kv_projs_and_split_head: bool = False, mask_node: OnnxNode | None = None, output_kv_projs: bool = False) → list[OnnxNode]
- build_encoder_decoder_mlp(base_name: str, input_node: OnnxNode, act_type: str) → OnnxNode