
Add rotary embedding fusion rule (part 1) #1981

Status: Open. Wants to merge 20 commits into base: main.
Conversation

@gramalingam (Collaborator) commented on Dec 18, 2024:

Initial version of fusion for rotary embedding.

Limitations: currently addresses only the non-interleaved, full-rotation case.
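
For orientation, here is a minimal numpy sketch of the non-interleaved ("rotate_half") computation that this fusion targets; the helper names are illustrative and are not the PR's code:

import numpy as np

def rotate_half(x: np.ndarray) -> np.ndarray:
    # Split the last (head_size) dimension into two halves and rotate: [x1, x2] -> [-x2, x1].
    half = x.shape[-1] // 2
    x1, x2 = x[..., :half], x[..., half:]
    return np.concatenate((-x2, x1), axis=-1)

def apply_rotary_embedding(x: np.ndarray, cos: np.ndarray, sin: np.ndarray) -> np.ndarray:
    # Non-interleaved, full-rotation form matched by the fusion rule.
    return x * cos + rotate_half(x) * sin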

Other:

  • Adds support for rewrite rules in which the matched nodes are not removed; this is useful when the matched nodes include nodes shared with other parts of the graph.
  • Adds an optimization to eliminate redundant Reshape ops (helps simplify the pattern).

codecov bot commented on Dec 18, 2024:

❌ 51 Tests Failed:

Tests completed | Failed | Passed | Skipped
12082           | 51     | 12031  | 2446

The 3 ❄️ flaky tests:
tests.eager_mode_test.TestEagerModeArguments_0_reference_runtime::test_function_input_and_attribute_by_kwargs_out_of_order

Flake rate in main: 39.26% (Passed 12653 times, Failed 8180 times)

Stack Traces | 0.002s run time
..../test_torch_nightly/lib/python3.11.../reference/ops/_op.py:91: in run
    res = self._run(x, y)
..../test_torch_nightly/lib/python3.11.../reference/ops/_op.py:139: in _run
    res = (convert_from_ml_dtypes(res[0]),)
..../test_torch_nightly/lib/python3.11.../onnx/reference/custom_element_types.py:50: in convert_from_ml_dtypes
    return array.view(dtype=dtype)
E   ValueError: Changing the dtype of a 0d array is only supported if the itemsize is unchanged

The above exception was the direct cause of the following exception:
tests/eager_mode_test.py:115: in test_function_input_and_attribute_by_kwargs_out_of_order
    self.assertEqual(add_with_alpha(alpha=3.0, other=2.0, this=1.0), 7.0)
onnxscript/values.py:576: in __call__
    return evaluator.default().eval_function(self, args, kwargs)
onnxscript/evaluator.py:307: in eval_function
    result = function.function(*adapted_args, **adapted_kwargs)
tests/eager_mode_test.py:59: in add_with_alpha
    other = op.Mul(other, alpha)
.../onnx_opset/_impl/opset14.py:696: in Mul
    return op(*self._prepare_inputs(schema, A, B))
onnxscript/values.py:304: in __call__
    return evaluator.default().eval(schema, args, kwargs)
onnxscript/evaluator.py:194: in eval
    outputs = self._eval(schema, inputs, attributes, closure)
onnxscript/evaluator.py:526: in _eval
    result = session.run(None, session_run_input)
..../test_torch_nightly/lib/python3.11.../onnx/reference/reference_evaluator.py:593: in run
    outputs = node.run(*inputs, **linked_attributes)
..../test_torch_nightly/lib/python3.11.../reference/ops/_op.py:114: in run
    res = OpRunBinary.run(self, x, y)
..../test_torch_nightly/lib/python3.11.../reference/ops/_op.py:93: in run
    raise TypeError(
E   TypeError: Issues with types <class 'numpy.ndarray'>, <class 'numpy.ndarray'> (binary operator 'Mul').
tests.eager_mode_test.TestEagerModeArguments_0_reference_runtime::test_function_all_input_by_kwargs

Flake rate in main: 39.26% (Passed 12653 times, Failed 8180 times)

Stack Traces | 0.002s run time. Same failure as the first flaky test above (ValueError from convert_from_ml_dtypes on a 0d array, surfacing as a TypeError from the reference Mul op), reached via tests/eager_mode_test.py:109 (self.assertEqual(add_with_alpha(this=1.0, other=2.0), 3.0)).
tests.eager_mode_test.TestEagerModeArguments_0_reference_runtime::test_function_attribute_by_positional_args

Flake rate in main: 39.26% (Passed 12653 times, Failed 8180 times)

Stack Traces | 0.002s run time. Same failure as above, reached via tests/eager_mode_test.py:112 (self.assertEqual(add_with_alpha(1.0, 2.0, 3.0), 7.0)).


@gramalingam changed the title from "Add rotary embedding fusion rule (part 1)" to "[Draft - WIP] Add rotary embedding fusion rule (part 1)" on Dec 20, 2024
@gramalingam changed the title from "[Draft - WIP] Add rotary embedding fusion rule (part 1)" to "Add rotary embedding fusion rule (part 1)" on Dec 23, 2024
# Slice(input, starts, ends, axes, steps)
x1 = op.Slice(x, start1, end1, [3], [1])
x2 = op.Slice(x, start2, end2, [3], [1])
minus_x2 = op.Neg(x2)
Contributor commented:
Although this logic is correct and makes sense, it doesn't match the function logic in the op definition. Is it correct to assume that the pattern logic should mimic the ONNX function in the op schema?

In the current op schema, the computation after x1 and x2 (which are produced with Split rather than Slice in the non-interleaved case) would look like:

real = cos * x1 - sin * x2
imag = sin * x1 + cos * x2
rotated_x = op.Concat(real, imag)

So the Concat happens after the multiplications.

Collaborator (Author) replied:
It does not have to match the function logic in the op definition. But it has to match the function graph produced by the ONNX exporter from the logic defined in the source (e.g., the transformers implementation).

What we do have to guarantee is that replacing this logic with the rewrite target is sound: both must produce the same values.
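
As a quick illustration of that equivalence, here is a hedged numpy sketch (not part of the PR; it assumes half-sized cos/sin caches that the exporter duplicates along the last axis) checking that the Slice/Neg/Concat pattern and the schema-style formula compute the same values:

import numpy as np

rng = np.random.default_rng(0)
batch, num_heads, seq, head_size = 1, 4, 3, 8
half = head_size // 2

x = rng.standard_normal((batch, num_heads, seq, head_size))
# Half-sized cos/sin caches, broadcast over batch and heads.
cos_cache = rng.standard_normal((1, 1, seq, half))
sin_cache = rng.standard_normal((1, 1, seq, half))

x1, x2 = x[..., :half], x[..., half:]

# Exporter-style pattern: duplicate the caches, rotate_half, then one multiply/add.
cos = np.concatenate((cos_cache, cos_cache), axis=-1)
sin = np.concatenate((sin_cache, sin_cache), axis=-1)
pattern_out = x * cos + np.concatenate((-x2, x1), axis=-1) * sin

# Op-schema-style formula: multiply the halves first, then concatenate.
real = cos_cache * x1 - sin_cache * x2
imag = sin_cache * x1 + cos_cache * x2
schema_out = np.concatenate((real, imag), axis=-1)

np.testing.assert_allclose(pattern_out, schema_out, rtol=1e-6)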

Collaborator (Author) added:
Specifically, it is more important to match the source logic, such as this transformers code.

Collaborator (Author) added:
Putting all this together, there are three parts to these rewrite rules:

  • The pattern should typically be aligned with the subgraph pattern we see in the ONNX graphs produced by the exporter (which itself depends on the source PyTorch code).
  • The rewrite part is aligned with the (fused) op definition (existing in ORT or being introduced to ONNX).
  • The check condition has to be strong enough to guarantee that the replacement is sound, so that we can be sure we will produce the same outputs with or without the optimization.

Collaborator (Author) commented:
@shubhambhokare1: please let me know if you have any further comments on this PR. It would be good to merge it in so that I can make progress on a number of other changes (adding more fusions); it is preferable to avoid merge conflicts across several independent, overlapping PRs.

return x * cos + _rotate_half_pattern(op, x, start1, end1, start2, end2) * sin

def check(self, op, x, start1, end1, start2, end2, **_):
    # x needs to be a 4D tensor with a known last-dimension size (head_size) and a known second dimension (num_heads)
Contributor commented:
According to the schema, x can be a 3D tensor as well, and num_heads then needs to be known in the 3D case.

Collaborator (Author) replied:
That means the optimization is safe and correct (in this regard). To generalize and also allow a 3D input here, we would need to guarantee that the entire fusion is semantically correct; it is not enough to know that the RotaryEmbedding op permits 3D inputs.

What do you think about the correctness of this fusion optimization? Do you think it is fine to generalize and allow 3D here?
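
For concreteness, here is a minimal sketch of the kind of 4D shape condition being discussed. This is a hypothetical helper written against a plain shape sequence, not the PR's actual check or the project's IR types:

from typing import Optional, Sequence, Union

def is_4d_with_known_heads_and_head_size(
    shape: Optional[Sequence[Union[int, str, None]]],
) -> bool:
    # True only for rank-4 shapes where num_heads (dim 1) and head_size (dim 3)
    # are statically known integers; batch and sequence dims may stay symbolic.
    if shape is None or len(shape) != 4:
        return False
    num_heads, head_size = shape[1], shape[3]
    return isinstance(num_heads, int) and isinstance(head_size, int)

# Example: [batch, num_heads, seq_len, head_size]
assert is_4d_with_known_heads_and_head_size(["batch", 32, "seq", 128])
assert not is_4d_with_known_heads_and_head_size(["batch", "num_heads", "seq", 128])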

def rewrite(self, op, x, cos, sin, **_):
    num_heads = x.shape[1]
    return op.RotaryEmbedding(
        x, cos, sin, interleaved=0, num_heads=num_heads, _domain="ai.onnxruntime.fusion"
    )
Contributor commented:
Just curious about the domain here

Collaborator (Author) replied:

Good question. I am currently using these domains to split the fusion optimization into multiple stages; we may need to clean this up eventually. For now, we also need to target the existing RotaryEmbedding op in ORT (which is also what we can test against). Eventually we can target the newly proposed RotaryEmbedding op, so we may need to support some variations in the fusion optimization (depending on the target ORT/ONNX versions).

@gramalingam enabled auto-merge (squash) on January 2, 2025, 20:57
Comment on lines 3 to +15
from __future__ import annotations

from onnxscript.rewriter.onnxruntime.xformers.cos_sin_cache import fuse_cos_sin_cache
from onnxscript.rewriter.onnxruntime.xformers.rms_normalization import fuse_rms_normalization
from onnxscript.rewriter.onnxruntime.xformers.rotary_embedding import fuse_rotary_embedding
from onnxscript.rewriter.onnxruntime.xformers.skip_normalization import fuse_normalization

__all__ = [
    "fuse_rms_normalization",
    "fuse_normalization",
    "fuse_rotary_embedding",
    "fuse_cos_sin_cache",
]
Collaborator commented:

Suggested change (move __all__ so that it comes right after the __future__ import):

from __future__ import annotations

__all__ = [
    "fuse_rms_normalization",
    "fuse_normalization",
    "fuse_rotary_embedding",
    "fuse_cos_sin_cache",
]

from onnxscript.rewriter.onnxruntime.xformers.cos_sin_cache import fuse_cos_sin_cache
from onnxscript.rewriter.onnxruntime.xformers.rms_normalization import fuse_rms_normalization
from onnxscript.rewriter.onnxruntime.xformers.rotary_embedding import fuse_rotary_embedding
from onnxscript.rewriter.onnxruntime.xformers.skip_normalization import fuse_normalization

nit: __all__ should come right after the future imports

_domain="ai.onnxruntime.fusion",
)

def check(self, context, inv_freq, position_ids, **_):
Collaborator commented:
Suggested change (add a bool return-type annotation):

def check(self, context, inv_freq, position_ids, **_) -> bool:

@@ -1197,6 +1201,7 @@ def match(
    graph_or_function: ir.Graph | ir.Function,
    node: ir.Node,
    verbose: int = 0,
    remove_nodes: bool = True,
Collaborator commented:
Helpful to document new args?
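
For instance, the new parameter could be documented along these lines. This is a hedged sketch; the wording, and the parts of the signature not shown in the hunk above, are illustrative rather than the PR's actual code:

def match(self, model, graph_or_function, node, verbose: int = 0, remove_nodes: bool = True):
    """Match the pattern against the graph rooted at the given node.

    Args:
        remove_nodes: If True (the default), the matched nodes are removed when the
            rule is applied. If False, the matched nodes are kept in the graph, which
            is useful when the matched subgraph shares nodes with other consumers.
    """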

@justinchuby (Collaborator) left a review:

Just some nits, thanks!
