diff --git a/docs/source/how-to/configure-workflows/onnx-graph-surgeon.md b/docs/source/how-to/configure-workflows/onnx-graph-surgeon.md index d9b3cb565..2cac249c4 100644 --- a/docs/source/how-to/configure-workflows/onnx-graph-surgeon.md +++ b/docs/source/how-to/configure-workflows/onnx-graph-surgeon.md @@ -385,6 +385,63 @@ graph { } ``` +### `ReplaceErfWithTanh` + +#### Description + +Replaces `Erf` nodes in the ONNX model with an equivalent computation using `Tanh`. The replacement involves scaling the input and applying the `Tanh` function to produce a result that approximates the `Erf` behavior. + +#### Example + +Initial ONNX model graph: + +``` +graph { + input: "input" + output: "erf_output" + node { + op_type: "Erf" + input: ["input"] + output: ["erf_output"] + } +} +``` + +After applying: + +```json +{ + "type": "GraphSurgeries", + "surgeries": [ + { + "surgeon": "ReplaceErfWithTanh" + } + ] +} +``` + +Transformed ONNX model graph: + +``` +graph { + input: "input" + initializer: "scale_0" (FLOAT, value: 1.203) + node { + op_type: "Mul" + input: ["input", "scale_0"] + output: ["mul_0"] + name: "Sub_Mul_0" + } + node { + op_type: "Tanh" + input: ["mul_0"] + output: ["erf_output"] + name: "Sub_Tanh_0" + } + output: "erf_output" +} +``` + ### `ZeroOutInput` #### Description diff --git a/olive/passes/onnx/graph_surgeries.py b/olive/passes/onnx/graph_surgeries.py index bc697f9e8..c68e9f10d 100644 --- a/olive/passes/onnx/graph_surgeries.py +++ b/olive/passes/onnx/graph_surgeries.py @@ -2,13 +2,16 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. 
class ReplaceErfWithTanh(Surgeon):
    """Replace each ``Erf`` node with ``Tanh(x * 605 / 503)``.

    For every ``Erf`` node a scalar initializer holding ``605 / 503`` is added
    to the graph, and the node is swapped for a ``Mul`` (input times scale)
    followed by a ``Tanh`` that writes to the original ``Erf`` outputs. The
    result approximates ``Erf``; it is not numerically identical.

    Only floating-point inputs are rewritten. ``Erf`` nodes whose input has any
    other element type are left in place and a warning is logged.
    """

    # ONNX element type -> numpy dtype used to round the scale constant.
    # Integer types are deliberately absent: casting 605 / 503 to an integer
    # dtype truncates it to 1, and ONNX ``Tanh`` is only defined for
    # floating-point tensors, so such a rewrite would not be valid.
    # numpy has no bfloat16, so the BFLOAT16 scale is kept as float32 here and
    # handed to ``make_tensor`` as a float value for encoding.
    DTYPE_MAP = {
        TensorProto.FLOAT: np.float32,
        TensorProto.FLOAT16: np.float16,
        TensorProto.DOUBLE: np.float64,
        TensorProto.BFLOAT16: np.float32,
    }

    def __init__(self):
        pass

    def __call__(self, model: ModelProto):
        """Rewrite ``model`` in place and return it.

        Raises:
            ValueError: if the element type of an ``Erf`` input cannot be found
                among the graph inputs, value_info entries or initializers.

        """
        idx = 0
        # Index-based walk: nodes are removed and inserted while iterating, so
        # the position is advanced manually past whatever was just inserted.
        while idx < len(model.graph.node):
            node = model.graph.node[idx]
            if node.op_type != "Erf":
                idx += 1
                continue

            # Copy what is needed from the node before it is detached from the graph.
            input_name = node.input[0]
            output_names = list(node.output)

            input_dtype = self._get_input_dtype(model, input_name)
            np_type = self.DTYPE_MAP.get(input_dtype)
            if np_type is None:
                logger.warning(
                    "Unsupported dtype %s for node %s. Skip replacing Erf with Tanh.", input_dtype, node.name
                )
                idx += 1
                continue

            model.graph.node.remove(node)
            name = f"scale_{idx}"
            output_scale = f"mul_{idx}"

            # scaling constant for tanh
            value = np.array(605 / 503, dtype=np_type)
            scale = onnx.helper.make_tensor(
                name=name,
                data_type=input_dtype,
                dims=value.shape,
                vals=value.flatten().tolist(),
            )
            model.graph.initializer.append(scale)

            mul_node = onnx.helper.make_node(
                "Mul", inputs=[input_name, name], outputs=[output_scale], name=f"Sub_Mul_{idx}"
            )
            tanh_node = onnx.helper.make_node(
                "Tanh", inputs=[output_scale], outputs=output_names, name=f"Sub_Tanh_{idx}"
            )

            # Put the two replacement nodes where the Erf node used to be.
            model.graph.node.insert(idx, mul_node)
            model.graph.node.insert(idx + 1, tanh_node)
            idx += 2
        return model

    def _get_input_dtype(self, model, name):
        """Return the ONNX element type recorded for tensor ``name``.

        Graph inputs are searched first, then ``value_info``, then
        initializers; the first match wins.

        Raises:
            ValueError: if ``name`` is found in none of the three.

        """
        for inp in model.graph.input:
            if inp.name == name:
                return inp.type.tensor_type.elem_type
        for vi in model.graph.value_info:
            if vi.name == name:
                return vi.type.tensor_type.elem_type
        for init in model.graph.initializer:
            if init.name == name:
                return init.data_type
        raise ValueError(f"Cannot find dtype for {name}")
outputs=["erf_output"], name="ErfNode") + graph_def = helper.make_graph(nodes=[erf_node], name="ErfTestGraph", inputs=[input_tensor], outputs=[output_tensor]) + model = helper.make_model(graph_def, producer_name="onnx-example") + onnx.save(model, model_path) + p = create_pass_from_dict( + GraphSurgeries, + {"surgeries": [{"surgeon": "ReplaceErfWithTanh"}]}, + disable_search=True, + ) + + # execute + onnx_model = p.run(ONNXModelHandler(model_path=str(model_path)), str(tmp_path / "onnx")) + + # assert + model_def = onnx_model.load_model() + tanh_node = next(node for node in model_def.graph.node if node.op_type == "Tanh") + mul_node = next(node for node in model_def.graph.node if node.op_type == "Mul") + + scale_initializer = next(init for init in model_def.graph.initializer if init.name == mul_node.input[1]) + scale_value = np.array(scale_initializer.float_data, dtype=np.float32) + assert np.isclose(scale_value, 605 / 503, atol=1e-6), "Scale value mismatch" + assert tanh_node.input[0] == mul_node.output[0], "Tanh input should match Mul output" + assert tanh_node.output[0] == "erf_output", "Tanh output should replace Erf output" + + def test_zero_out_input(tmp_path): # setup input_model_path = tmp_path / "model.onnx"