-
Notifications
You must be signed in to change notification settings - Fork 57
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add rotary embedding fusion rule (part 1) #1981
base: main
Are you sure you want to change the base?
Conversation
❌ 51 Tests Failed:
View the full list of 3 ❄️ flaky tests
To view more test analytics, go to the Test Analytics Dashboard |
# Slice(input, starts, ends, axes, steps) | ||
x1 = op.Slice(x, start1, end1, [3], [1]) | ||
x2 = op.Slice(x, start2, end2, [3], [1]) | ||
minus_x2 = op.Neg(x2) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Although this logic is correct and makes sense, this doesn't match the function logic in the op definition. Is it correct to assume that the pattern logic should mimic the onnx function in the op schema?
Currently in the op schema, this pattern would look like after x1, x2 (which uses split instead of slice for non-interleaved case):
real = cos * x1 - sin * x2
imag = sin * x1 + cos * x2
rotated_x = op.Concat(real, imag)
So the concat happens after the multiplication
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It does not have to match the function logic in the op definition. But it has to match the function graph produced by the ONNX exporter from the logic defined in the source (eg., the transformers implementation).
But what we have to guarantee or ensure is that replacing this logic by the pattern in rewrite
is fine: that they will both produce the same values.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Specifically, it is more important to match the source logic like this transformer code
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Putting all this together, there are 3 parts to these rewrite-rules:
- the
pattern
should typically be aligned with the subgraph pattern we see in the ONNX graphs produced by the exporter (which itself depends on the source pytorch code). - the
rewrite
part is aligned with the (fused) op definition (existing in ORT or being introduced to ONNX). - the
check
condition has to be strong enough to guarantee that the replacement is sound. So, that we can be sure we will produce the same outputs with or without the optimization.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@shubhambhokare1 : please let me know if you have any further comments on this PR. It would be good to merge it in, so that I can make progress on a bunch of other changes (adding more fusions) ... preferable to avoid merge-conflicts with several independent, overlapping, PRs
return x * cos + _rotate_half_pattern(op, x, start1, end1, start2, end2) * sin | ||
|
||
def check(self, op, x, start1, end1, start2, end2, **_): | ||
# x needs to be a 4D tensor with known last dimension size (== head_size) and known second dimension (num_heads) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
According to schema, x can be a 3D tensor as well. And num_heads are necessary to be known in cases with 3D tensor.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That means the optimization is safe and correct (in this regard). To generalize and allow a 3D also here, we would need to guarantee that the entire-fusion is guaranteed to be semantically correct ... it is not enough to know that the RotaryEmbedding op permits 3D inputs.
What do you think about the correctness of this fusion optimization? Do you think it is fine to generalize and allow 3D here?
def rewrite(self, op, x, cos, sin, **_): | ||
num_heads = x.shape[1] | ||
return op.RotaryEmbedding( | ||
x, cos, sin, interleaved=0, num_heads=num_heads, _domain="ai.onnxruntime.fusion" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just curious about the domain here
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good question. I am currently using these to split the fusion optimization into multiple stages. We may need to clean this up finally. For now, we also need to target the existing RotaryEmbedding op in ORT (which is what we can test against also). Eventually, we can target the new proposed RotaryEmbedding op ... so we may also need to support some variations in the fusion optimization (depending on target ORT/ONNX versions).
from __future__ import annotations | ||
|
||
from onnxscript.rewriter.onnxruntime.xformers.cos_sin_cache import fuse_cos_sin_cache | ||
from onnxscript.rewriter.onnxruntime.xformers.rms_normalization import fuse_rms_normalization | ||
from onnxscript.rewriter.onnxruntime.xformers.rotary_embedding import fuse_rotary_embedding | ||
from onnxscript.rewriter.onnxruntime.xformers.skip_normalization import fuse_normalization | ||
|
||
__all__ = [ | ||
"fuse_rms_normalization", | ||
"fuse_normalization", | ||
"fuse_rotary_embedding", | ||
"fuse_cos_sin_cache", | ||
] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
from __future__ import annotations | |
from onnxscript.rewriter.onnxruntime.xformers.cos_sin_cache import fuse_cos_sin_cache | |
from onnxscript.rewriter.onnxruntime.xformers.rms_normalization import fuse_rms_normalization | |
from onnxscript.rewriter.onnxruntime.xformers.rotary_embedding import fuse_rotary_embedding | |
from onnxscript.rewriter.onnxruntime.xformers.skip_normalization import fuse_normalization | |
__all__ = [ | |
"fuse_rms_normalization", | |
"fuse_normalization", | |
"fuse_rotary_embedding", | |
"fuse_cos_sin_cache", | |
] | |
from __future__ import annotations | |
__all__ = [ | |
"fuse_rms_normalization", | |
"fuse_normalization", | |
"fuse_rotary_embedding", | |
"fuse_cos_sin_cache", | |
] | |
from onnxscript.rewriter.onnxruntime.xformers.cos_sin_cache import fuse_cos_sin_cache | |
from onnxscript.rewriter.onnxruntime.xformers.rms_normalization import fuse_rms_normalization | |
from onnxscript.rewriter.onnxruntime.xformers.rotary_embedding import fuse_rotary_embedding | |
from onnxscript.rewriter.onnxruntime.xformers.skip_normalization import fuse_normalization |
nit: all should be right after future imports
_domain="ai.onnxruntime.fusion", | ||
) | ||
|
||
def check(self, context, inv_freq, position_ids, **_): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
def check(self, context, inv_freq, position_ids, **_): | |
def check(self, context, inv_freq, position_ids, **_) -> bool: |
@@ -1197,6 +1201,7 @@ def match( | |||
graph_or_function: ir.Graph | ir.Function, | |||
node: ir.Node, | |||
verbose: int = 0, | |||
remove_nodes: bool = True, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Helpful to document new args?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just some nits, thanks!
Initial version of fusion for rotary embedding.
Limitations: currently addresses only non-interleaved and full rotation.
Other: