Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ROCm] Implement RNN support #25755

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

Ruturaj4
Copy link
Collaborator

@Ruturaj4 Ruturaj4 commented Jan 7, 2025

Created from: ROCm#171

@Ruturaj4
Copy link
Collaborator Author

Ruturaj4 commented Jan 7, 2025

@dfm and @superbobry could you please take a look?

@github-actions github-actions bot force-pushed the ci_rnn_final-upstream branch from 0b07837 to 36d037e Compare January 7, 2025 19:08
Copy link
Collaborator

@superbobry superbobry left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@dfm want to have a look as well?

@jax.default_matmul_precision("float32")
def test_lstm(self, batch_size: int, seq_len: int, input_size: int,
hidden_size: int, num_layers: int, bidirectional: bool):
# TODO(phawkins): Partially disable this on cudnn version per b/281071013
if (batch_size == 1 and seq_len == 4 and input_size == 1 and
if jtu.is_device_cuda() and (batch_size == 1 and seq_len == 4 and input_size == 1 and
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can just nuke this condition since JAX requires cuDNN >=9.1.

@@ -61,6 +65,7 @@ def test_lstm(self, batch_size: int, seq_len: int, input_size: int,
weights = rnn.init_lstm_weight(k4, input_size, hidden_size, num_layers,
bidirectional)
def f(weights, x, h_0, c_0):
weights = rnn.swap_lstm_gates(weights, input_size, hidden_size, num_layers, bidirectional)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How did the test pass without this call before?

@google-ml-butler google-ml-butler bot added kokoro:force-run pull ready Ready for copybara import and testing labels Jan 8, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
pull ready Ready for copybara import and testing
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants