Add ReplaceErfWithTanh to GraphSurgeries (#1521)
## Describe your changes

Add ReplaceErfWithTanh to GraphSurgeries

## Checklist before requesting a review
- [ ] Add unit tests for this change.
- [ ] Make sure all tests can pass.
- [ ] Update documents if necessary.
- [ ] Lint and apply fixes to your code by running `lintrunner -a`
- [ ] Is this a user-facing change? If yes, give a description of this
change to be included in the release notes.
- [ ] Is this PR including examples changes? If yes, please remember to
update [example
documentation](https://github.com/microsoft/Olive/blob/main/docs/source/examples.md)
in a follow-up PR.

## (Optional) Issue link
xiaoyu-work authored Dec 12, 2024
1 parent 9191ba6 commit dc67008
Showing 3 changed files with 168 additions and 1 deletion.
57 changes: 57 additions & 0 deletions docs/source/how-to/configure-workflows/onnx-graph-surgeon.md
@@ -385,6 +385,63 @@ graph {
}
```

### `ReplaceErfWithTanh`

#### Description

Replaces `Erf` nodes in the ONNX model with an equivalent computation using `Tanh`. The surgeon scales the input by a constant factor (605/503 ≈ 1.203) and applies `Tanh` to the result, which approximates the `Erf` activation.

#### Example

Initial ONNX model graph:

```
graph {
input: "input"
output: "erf_output"
node {
op_type: "Erf"
input: ["input"]
output: ["erf_output"]
}
}
```

After applying:

```json
{
"type": "GraphSurgeries",
"surgeries": [
{
"surgeon": "ReplaceErfWithTanh"
}
]
}
```

Transformed ONNX model graph:

```
graph {
input: "input"
initializer: "scale_0" (FLOAT, value: 1.203)
node {
op_type: "Mul"
input: ["input", "scale_0"]
output: ["mul_0"]
name: "Sub_Mul_0"
}
node {
op_type: "Tanh"
input: ["mul_0"]
output: ["erf_output"]
name: "Sub_Tanh_0"
}
output: "erf_output"
}
```
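
For intuition, the rewrite relies on the approximation `erf(x) ≈ tanh(605/503 · x)`, using the same scale constant the surgeon adds as an initializer. Below is a minimal standalone sketch (plain Python, independent of Olive) comparing the two functions:

```python
import math

# Compare erf(x) with tanh(scale * x), using the same constant the surgeon inserts.
scale = 605 / 503  # ~1.2028
for x in [0.0, 0.5, 1.0, 2.0, 3.0]:
    exact = math.erf(x)
    approx = math.tanh(scale * x)
    print(f"x={x:.1f}  erf={exact:+.4f}  tanh(scale*x)={approx:+.4f}  |diff|={abs(exact - approx):.4f}")
```

The printed differences give a quick sense of how close the substitution is; accuracy of the surgically modified model should still be validated on the target workload.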

### `ZeroOutInput`

#### Description
82 changes: 81 additions & 1 deletion olive/passes/onnx/graph_surgeries.py
@@ -2,13 +2,16 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------

# ruff: noqa: RUF012

import inspect
import logging
from typing import Any, ClassVar, Dict, List, Type

import numpy as np
import onnx
- from onnx import ModelProto
+ from onnx import ModelProto, TensorProto
from onnx.helper import make_tensor

from olive.hardware.accelerator import AcceleratorSpec
@@ -162,6 +165,83 @@ def __call__(self, model: ModelProto):
return model


class ReplaceErfWithTanh(Surgeon):

DTYPE_MAP = {
TensorProto.FLOAT: np.float32,
TensorProto.FLOAT16: np.float16,
TensorProto.DOUBLE: np.float64,
TensorProto.BFLOAT16: np.uint16,
TensorProto.INT8: np.int8,
TensorProto.INT16: np.int16,
TensorProto.INT32: np.int32,
TensorProto.INT64: np.int64,
TensorProto.UINT8: np.uint8,
TensorProto.UINT16: np.uint16,
TensorProto.UINT32: np.uint32,
TensorProto.UINT64: np.uint64,
}

def __init__(self):
pass

def __call__(self, model: ModelProto):
idx = 0
while idx < len(model.graph.node):
node = model.graph.node[idx]
if node.op_type == "Erf":
inputs = node.input
outputs = node.output
input_dtype = self._get_input_dtype(model, inputs[0])
np_type = self.DTYPE_MAP.get(input_dtype)
if np_type is None:
logger.warning(
"Unsupported dtype %s for node %s. Skip replacing Erf with Tanh.", input_dtype, node.name
)
idx += 1
continue

model.graph.node.remove(node)
name = f"scale_{idx}"
output_scale = f"mul_{idx}"

# scaling constant for tanh
value = np.array(605 / 503, dtype=np_type)
scale = onnx.helper.make_tensor(
name=name,
data_type=input_dtype,
dims=value.shape,
vals=value.flatten().tolist(),
)
model.graph.initializer.append(scale)

mul_node = onnx.helper.make_node(
"Mul", inputs=[inputs[0], name], outputs=[output_scale], name=f"Sub_Mul_{idx}"
)
tanh_node = onnx.helper.make_node(
"Tanh", inputs=[output_scale], outputs=outputs, name=f"Sub_Tanh_{idx}"
)

model.graph.node.insert(idx, mul_node)
model.graph.node.insert(idx + 1, tanh_node)
idx += 2
else:
idx += 1
return model

def _get_input_dtype(self, model, name):
for inp in model.graph.input:
if inp.name == name:
return inp.type.tensor_type.elem_type
for vi in model.graph.value_info:
if vi.name == name:
return vi.type.tensor_type.elem_type
for init in model.graph.initializer:
if init.name == name:
return init.data_type
raise ValueError(f"Cannot find dtype for {name}")


class ZeroOutInput(Surgeon):
def __init__(self, node_name, input_idx):
self.node_name = node_name
30 changes: 30 additions & 0 deletions test/unit_test/passes/onnx/test_graph_surgeries.py
@@ -195,6 +195,36 @@ def test_reorder_inputs(tmp_path):
assert [graph_input.name for graph_input in model_def.graph.input] == ["input2", "input1"]


def test_replace_erf_with_tanh(tmp_path):
# setup
model_path = tmp_path / "model.onnx"
input_tensor = helper.make_tensor_value_info("input", TensorProto.FLOAT, [1, 3])
output_tensor = helper.make_tensor_value_info("erf_output", TensorProto.FLOAT, [1, 3])
erf_node = helper.make_node("Erf", inputs=["input"], outputs=["erf_output"], name="ErfNode")
graph_def = helper.make_graph(nodes=[erf_node], name="ErfTestGraph", inputs=[input_tensor], outputs=[output_tensor])
model = helper.make_model(graph_def, producer_name="onnx-example")
onnx.save(model, model_path)
p = create_pass_from_dict(
GraphSurgeries,
{"surgeries": [{"surgeon": "ReplaceErfWithTanh"}]},
disable_search=True,
)

# execute
onnx_model = p.run(ONNXModelHandler(model_path=str(model_path)), str(tmp_path / "onnx"))

# assert
model_def = onnx_model.load_model()
tanh_node = next(node for node in model_def.graph.node if node.op_type == "Tanh")
mul_node = next(node for node in model_def.graph.node if node.op_type == "Mul")

scale_initializer = next(init for init in model_def.graph.initializer if init.name == mul_node.input[1])
scale_value = np.array(scale_initializer.float_data, dtype=np.float32)
assert np.isclose(scale_value, 605 / 503, atol=1e-6), "Scale value mismatch"
assert tanh_node.input[0] == mul_node.output[0], "Tanh input should match Mul output"
assert tanh_node.output[0] == "erf_output", "Tanh output should replace Erf output"


def test_zero_out_input(tmp_path):
# setup
input_model_path = tmp_path / "model.onnx"
