
[WebNN] Fix bug in SkipSimplifiedLayerNormalization #23236

@@ -76,8 +76,6 @@ Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder
   options.set("epsilon", epsilon);

   emscripten::val output = emscripten::val::undefined();
-  // SkipSimplifiedLayerNormalization's output: input_skip_bias_sum.
-  emscripten::val input_skip_bias_sum = emscripten::val::undefined();
   if (op_type == "BatchNormalization") {
     ORT_RETURN_IF_NOT(input_defs.size() == 5, "BatchNormalization requires five inputs.");
     emscripten::val mean = model_builder.GetOperand(input_defs[3]->Name());
@@ -107,14 +105,31 @@ Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder
         | | | | | |
        Y:2 axis B:epsilon A:X A:scale B:bias

-       If it is SkipSimplifiedLayerNormalization and its output input_skip_bias_sum exists,
+       If it is SkipSimplifiedLayerNormalization, X should be input_skip_bias_sum:
+       input_skip_bias_sum = X + skip + bias (if it exists)
     */

     int32_t input_type;
     ORT_RETURN_IF_NOT(GetType(*input_defs[0], input_type, logger), "Cannot get input type");
     emscripten::val common_options = emscripten::val::object();

+     // If it is SkipSimplifiedLayerNormalization, add the skip and bias (if it exists) to the input.
+     if (op_type == "SkipSimplifiedLayerNormalization") {
+       emscripten::val skip = model_builder.GetOperand(input_defs[1]->Name());
+       common_options.set("label", node.Name() + "_add_skip");
+       input = model_builder.GetBuilder().call<emscripten::val>("add", input, skip, common_options);
+       if (!bias.isUndefined()) {
+         common_options.set("label", node.Name() + "_add_skip_bias");
+         input = model_builder.GetBuilder().call<emscripten::val>("add", input, bias, common_options);
+       }
+
+       // Add SkipSimplifiedLayerNormalization's output input_skip_bias_sum if it exists.
+       // Now input equals input_skip_bias_sum.
+       if (TensorExists(output_defs, 3)) {
+         model_builder.AddOperand(output_defs[3]->Name(), input);
+       }
+     }

     // Pow
     emscripten::val pow_constant = model_builder.CreateOrGetConstant<float>(input_type, 2);
     common_options.set("label", node.Name() + "_pow");
@@ -151,19 +166,6 @@ Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder
       common_options.set("label", node.Name() + "_add_bias");
       output = model_builder.GetBuilder().call<emscripten::val>("add", output, bias, common_options);
     }
-
-    // SkipSimplifiedLayerNormalization's output input_skip_bias_sum is the sum of input, skip, and bias.
-    if (op_type == "SkipSimplifiedLayerNormalization" && TensorExists(output_defs, 3)) {
-      emscripten::val skip = model_builder.GetOperand(input_defs[1]->Name());
-      common_options.set("label", node.Name() + "_add_skip");
-      input_skip_bias_sum = model_builder.GetBuilder().call<emscripten::val>("add", input, skip, common_options);
-      if (!bias.isUndefined()) {
-        common_options.set("label", node.Name() + "_add_skip_bias");
-        input_skip_bias_sum = model_builder.GetBuilder().call<emscripten::val>(
-            "add", input_skip_bias_sum, bias, common_options);
-      }
-      model_builder.AddOperand(output_defs[3]->Name(), std::move(input_skip_bias_sum));
-    }
   }
 } else if (op_type == "InstanceNormalization") {
   // WebNN spec only supports 4D input for instanceNormalization.
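
The substance of the fix: previously, input_skip_bias_sum was computed only at the end to populate the optional fourth output, while the normalization itself (the Pow/ReduceMean chain) still consumed the raw input. The change computes input + skip (+ bias) up front, so the decomposition normalizes the sum, and registers that same tensor as the optional output. Below is a minimal, self-contained sketch of those semantics, assuming normalization over the last axis; the function and names are illustrative only, not part of the ONNX Runtime or WebNN APIs.

// Illustrative sketch of SkipSimplifiedLayerNormalization semantics (not ORT code).
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// x, skip, and y are [rows, cols]; bias and scale are [cols]; bias may be empty.
void SkipSimplifiedLayerNorm(const std::vector<float>& x,
                             const std::vector<float>& skip,
                             const std::vector<float>& bias,  // optional, may be empty
                             const std::vector<float>& scale,
                             float epsilon,
                             size_t rows, size_t cols,
                             std::vector<float>& y,
                             std::vector<float>& input_skip_bias_sum) {
  assert(x.size() == rows * cols && skip.size() == rows * cols);
  y.resize(rows * cols);
  input_skip_bias_sum.resize(rows * cols);
  for (size_t r = 0; r < rows; ++r) {
    // 1. input_skip_bias_sum = x + skip (+ bias). This sum is both the optional
    //    fourth output and the tensor the normalization must operate on.
    float mean_sq = 0.f;
    for (size_t c = 0; c < cols; ++c) {
      float sum = x[r * cols + c] + skip[r * cols + c] + (bias.empty() ? 0.f : bias[c]);
      input_skip_bias_sum[r * cols + c] = sum;
      mean_sq += sum * sum;  // accumulates the Pow(2) -> ReduceMean chain
    }
    mean_sq /= static_cast<float>(cols);
    // 2. Simplified layer norm (RMS norm): divide by sqrt(mean(sum^2) + epsilon), then scale.
    const float inv_rms = 1.f / std::sqrt(mean_sq + epsilon);
    for (size_t c = 0; c < cols; ++c) {
      y[r * cols + c] = input_skip_bias_sum[r * cols + c] * inv_rms * scale[c];
    }
  }
}

int main() {
  const size_t rows = 1, cols = 4;
  std::vector<float> x = {1, 2, 3, 4}, skip = {4, 3, 2, 1}, bias;  // bias omitted
  std::vector<float> scale = {1, 1, 1, 1}, y, sum;
  SkipSimplifiedLayerNorm(x, skip, bias, scale, 1e-5f, rows, cols, y, sum);
  // Every element of sum is 5, so mean(sum^2) = 25 and each y[c] ~ 1.
  return 0;
}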