
onnxruntime-genai generation speed very slow on int4 #1098

Open
tarekziade opened this issue Nov 23, 2024 · 7 comments

@tarekziade (Contributor) commented Nov 23, 2024

I have built a small example using the Python bindings (https://github.com/tarekziade/onnxruntime-test/blob/main/run.py) to measure inference speed on my Apple M1 and on a Windows 11 box, using Qwen 2.5 0.5B Instruct.
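
Roughly, the measurement boils down to a timing loop like the following (a simplified sketch, not the actual run.py; the prompt and max_length are placeholders, and the exact onnxruntime-genai Python API may differ slightly between releases):

import time
import onnxruntime_genai as og

model = og.Model("qwen")                      # folder produced by the model builder
tokenizer = og.Tokenizer(model)

prompt = "Explain the difference between a list and a tuple in Python."
input_tokens = tokenizer.encode(prompt)

params = og.GeneratorParams(model)
params.set_search_options(max_length=256)
params.input_ids = input_tokens

start = time.perf_counter()
output_tokens = model.generate(params)[0]     # high-level generate; one sequence per prompt
elapsed = time.perf_counter() - start

new_tokens = len(output_tokens) - len(input_tokens)
print(f"{new_tokens / elapsed:.2f} tokens/s")
print(tokenizer.decode(output_tokens))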

To prepare the model, I used the CPU execution provider with int4, fp16, and fp32 precisions:

python3 -m onnxruntime_genai.models.builder -m "Qwen/Qwen2.5-0.5B-Instruct" -o qwen -p int4 -e cpu
python3 -m onnxruntime_genai.models.builder -m "Qwen/Qwen2.5-0.5B-Instruct" -o qwen -p fp32 -e cpu
python3 -m onnxruntime_genai.models.builder -m "Qwen/Qwen2.5-0.5B-Instruct" -o qwen -p fp16 -e cpu

I then compared the execution times with llama-cli, using a q4_0 GGUF of the same model.
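
For the llama.cpp side, an equivalent run looks something like this (illustrative only; the GGUF filename, prompt, and token count are placeholders):

llama-cli -m qwen2.5-0.5b-instruct-q4_0.gguf -p "Explain the difference between a list and a tuple in Python." -n 256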

[Charts: tokens per second on Apple M1 and Windows 11, comparing onnxruntime-genai and llama.cpp across precisions]

On Apple, the int4 precision is extremely slow, and fp16 failed on both platforms with:

onnxruntime_genai.onnxruntime_genai.OrtException:
Non-zero status code returned while running Cast node.
Name:'InsertedPrecisionFreeCast_/model/layers.1/attn/v_proj/repeat_kv/Reshape_4/output_0'
Status Message: /Users/runner/work/1/s/onnxruntime/core/framework/op_kernel.cc:83 virtual OrtValue *onnxruntime::OpKernelContext::OutputMLValue(int, const onnxruntime::TensorShape &) status.IsOK() was false.
Shape mismatch attempting to re-use buffer. {1,1,896} != {1,248,896}.
Validate usage of dim_value (values should be > 0) and dim_param (all values with the same string should equate to the same size) in shapes in the model.

I was wondering if I did something wrong. I was also wondering whether int8 precision is an option; it looks like onnxruntime_genai.models.builder can apply some int8 quantization via the int4 mode, but I am not entirely clear on this.

@tarekziade changed the title from "onnxruntime-genai generation speed" to "onnxruntime-genai generation speed very slow on int4" on Nov 23, 2024
@elephantpanda commented

Your graph says "tokens per second" not "execution time".

Your graph says int4 does the most tokens per second.

So your graph seems to be saying the opposite of what you are saying, unless you labelled the axis wrong? 😕

@tarekziade (Contributor, Author) commented Nov 24, 2024

Your graph says "tokens per second" not "execution time".

Yes. That's a way to measure execution time -- or at least "speed" :)

Your graph says int4 does the most tokens per second.

Correct: for llama-cli it's the highest, at 140 tokens/s.
For ONNX, it's the lowest, at 4.85 tokens/s.

Results in JSON: https://github.com/tarekziade/onnxruntime-test/blob/main/results.json

So your graph seems to be saying the opposite of what you are saying, unless you labelled the axis wrong? 😕

I don't think it does; maybe what is confusing is that the graph includes both ONNX and llama.cpp results?

@elephantpanda commented Nov 24, 2024

I see, so the ones labelled "onnx" are the ones you are running in GenAI and the ones labelled "llama" are the ones running in llama.cpp.
Sorry, I got confused because Llama is also the name of an LLM. That looks very bad on the Mac; guessing they haven't optimised it for Mac yet. 😔 I have tried int4 on Windows DML, and when it was partially working (version 0.4.0) it was very fast.

@ambroser53 commented Nov 25, 2024

+1 to two of the issues raised. I am getting the exact same error on the fp16 version of my model:

Shape mismatch attempting to re-use buffer. {1,1,3072} != {1,808,3072}. Validate usage of dim_value (values should be > 0) and dim_param (all values with the same string should equate to the same size) in shapes in the model.

Plus I am also confused about the int8 support in the model builder. It seems it is supported to a certain extent:

io_dtype = TensorProto.FLOAT if precision in {"int8", "fp32"}

But similarly I get an error if I actually attempt to use it:

NotImplementedError: The int8 precision is not currently supported.

Clarification would be helpful (especially as int8 can be supported through other means, such as exporting with GPTQ or using TensorRT's Model Optimizer).
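
To make the apparent mismatch concrete, here is an illustrative sketch (hypothetical names only, not the actual model builder code): the I/O dtype selection treats "int8" like an FP32-I/O precision, while a separate validation step rejects it outright.

from onnx import TensorProto

SUPPORTED_PRECISIONS = {"fp32", "fp16", "int4"}   # hypothetical allow-list, for illustration only

def select_io_dtype(precision: str) -> int:
    # "int8" shares the FP32 input/output path here...
    return TensorProto.FLOAT if precision in {"int8", "fp32"} else TensorProto.FLOAT16

def validate_precision(precision: str) -> None:
    # ...but this check raises before any int8 model is actually built.
    if precision not in SUPPORTED_PRECISIONS:
        raise NotImplementedError(f"The {precision} precision is not currently supported.")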

@xenova commented Dec 3, 2024

+1 to the issues for fp16 versions of this model:

Error: Non-zero status code returned while running Cast node. Name:'InsertedPrecisionFreeCast_/model/layers.1/attn/v_proj/repeat_kv/Reshape_4/output_0' Status Message: /onnxruntime_src/onnxruntime/core/framework/op_kernel.cc:83 virtual OrtValue* onnxruntime::OpKernelContext::OutputMLValue(int, const onnxruntime::TensorShape&) status.IsOK() was false. Shape mismatch attempting to re-use buffer. {1,1,896} != {1,908,896}. Validate usage of dim_value (values should be > 0) and dim_param (all values with the same string should equate to the same size) in shapes in the model.

Subscribing to this thread 👍

@elephantpanda commented

+1. I have updated from 0.4.0 to 0.5.2 and am seeing about a 3x slowdown with the Phi-3 int4 ONNX model. 😟 (I haven't tested other models.)

@kunal-vaishnavi (Contributor) commented

python3 -m onnxruntime_genai.models.builder -m "Qwen/Qwen2.5-0.5B-Instruct" -o qwen -p fp16 -e cpu

Error: Non-zero status code returned while running Cast node. Name:'InsertedPrecisionFreeCast_/model/layers.1/attn/v_proj/repeat_kv/Reshape_4/output_0' Status Message: /onnxruntime_src/onnxruntime/core/framework/op_kernel.cc:83 virtual OrtValue* onnxruntime::OpKernelContext::OutputMLValue(int, const onnxruntime::TensorShape&) status.IsOK() was false. Shape mismatch attempting to re-use buffer. {1,1,896} != {1,908,896}. Validate usage of dim_value (values should be > 0) and dim_param (all values with the same string should equate to the same size) in shapes in the model.

FP16 CPU is not officially supported in ONNX Runtime. While an FP16 CPU model can be created, many of the model's operators still need to be implemented for FP16 CPU. When an operator is not implemented for FP16 CPU, a Cast node is inserted to upcast to FP32 to perform the computation. If you build with debug information or if you encounter any runtime errors, the InsertedPrecisionFreeCast_ prefix in a node's input or output names will indicate this. Work is in progress to add FP16 CPU support for many operators, which will enable using the same ONNX model with the same precision on both CPU and GPU.
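
If you want to see which Cast nodes get inserted for an FP16 CPU model, one option (a rough sketch; the paths are placeholders, and it assumes the optimized graph fits in a single protobuf) is to save ONNX Runtime's optimized graph and scan it for the prefix:

import onnx
import onnxruntime as ort

# Save the graph as optimized during session initialization, then inspect it.
sess_options = ort.SessionOptions()
sess_options.optimized_model_filepath = "qwen_fp16_optimized.onnx"
ort.InferenceSession("qwen/model.onnx", sess_options, providers=["CPUExecutionProvider"])

optimized = onnx.load("qwen_fp16_optimized.onnx")
inserted = [n.name for n in optimized.graph.node
            if n.op_type == "Cast" and n.name.startswith("InsertedPrecisionFreeCast_")]
print(f"{len(inserted)} inserted Cast nodes")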

Thus, although the model builder can create an FP16 CPU model, the following message is printed at the beginning.

print("Valid precision + execution provider combinations are: FP32 CPU, FP32 CUDA, FP16 CUDA, FP16 DML, INT4 CPU, INT4 CUDA, INT4 DML")

On Apple, the int4 precision is extremely slow

I have updated from 0.4.0 to 0.5.2 and am experiencing about 3x slow down in speed

Performance gains with INT4 precision should come from the ONNX Runtime version you have installed. This is because the quantized weights are handled within the ONNX model itself, while the global inputs and outputs to the ONNX model are still in FP16 or FP32 precision. Since ONNX Runtime GenAI only manages the global inputs and outputs to the model and relies upon ONNX Runtime to run it, ONNX Runtime GenAI itself should not be causing this issue.

Can you try downgrading your ONNX Runtime version to an older one? When installing the newest ONNX Runtime GenAI package, it will try to upgrade you to ONNX Runtime v1.20.1 (see the requires_dist section here). For ONNX Runtime GenAI v0.4.0, the specified ONNX Runtime version was v1.19.0. It is possible there is an INT4 performance regression in the latest ONNX Runtime version.
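
For example (adjust the version pin and the package variant, e.g. onnxruntime-gpu, to match your setup; pip may warn about the dependency pin from the GenAI package):

pip install onnxruntime==1.19.0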

Before re-benchmarking, you may need to re-build the ONNX model by commenting out this line and removing softcap = self.attention_attrs["softcap"] here, since these two features come from the latest ONNX Runtime release and may not be part of older ONNX Runtime releases.
