FP8 Quantization

  1. Introduction
  2. Supported Parameters
  3. Get Started with FP8 Quantization
  4. Optimum-habana LLM example
  5. vLLM example

Introduction

Floating point 8 (FP8) is a promising data type for low-precision quantization. It provides a data distribution that is completely different from INT8, as illustrated in Fig. 1 (E5M2) and Fig. 2 (E4M3).

Intel Gaudi2, also known as HPU, provides this data type capability for low-precision quantization in two formats, E4M3 and E5M2. For more information about these two data types, please refer to this link.

Intel Neural Compressor provides general quantization APIs that leverage the HPU FP8 capability, producing an 8-bit model with lower memory usage and lower compute cost.

Supported Parameters

fp8_config: The target data type of FP8 quantization.
- E4M3 (default) - as Fig. 2
- E5M2 - as Fig. 1

hp_dtype: The high precision data type of non-FP8 operators.
- bf16 (default) - torch.bfloat16
- fp16 - torch.float16
- fp32 - torch.float32

observer: The observer used to measure the statistics.
- maxabs (default), saves all tensors to files.

allowlist: List of nn.Module names or types to quantize. When an empty list is set, all supported modules are quantized by default (see Supported Modules). Not setting the list at all is not recommended, as it restricts the allowlist to these modules only: torch.nn.Linear, torch.nn.Conv2d, and BMM.
- Default = {'names': [], 'types': FP8_WHITE_LIST}

blocklist: List of nn.Module names or types not to quantize. Defaults to an empty list, so it may be omitted from the config file.
- Default = {'names': [], 'types': ()}

mode: The mode, measure or quantize, to run HQT with.
- MEASURE - Measure statistics of all modules and emit the results to dump_stats_path.
- QUANTIZE - Quantize and run the model according to the provided measurements.
- AUTO (default) - Select from [MEASURE, QUANTIZE] automatically.

dump_stats_path: The path for saving and loading the measurements. The directory portion (everything before the last "/") is created; the string after the last "/" is used as a prefix for all measurement files that are created.
- Default = "./hqt_output/measure"

scale_method: The method for calculating the scale from the measurement.
- unit_scale - Always use a scale of 1.
- hw_aligned_single_scale - Always use a scale that is aligned to the corresponding HW accelerated scale.
- maxabs_hw (default) - The scale is calculated to stretch/compress the maxabs measurement to the full scale of FP8 and then aligned to the corresponding HW accelerated scale.
- maxabs_pow2 - The scale is calculated to stretch/compress the maxabs measurement to the full scale of FP8 and then rounded to a power of 2.
- maxabs_hw_opt_weight - The scale of model params (weights) is chosen as the scale that gives the minimal mean-square error between quantized and non-quantized weights, out of all possible HW accelerated scales. The scale of activations is calculated as in maxabs_hw.
- act_maxabs_pow2_weights_pcs_opt_pow2 - The scale of model params (weights) is calculated per channel of the params tensor. The per-channel scale is calculated as in maxabs_hw_opt_weight. The scale of activations is calculated as in maxabs_pow2.
- act_maxabs_hw_weights_pcs_maxabs_pow2 - The scale of model params (weights) is calculated per channel of the params tensor. The per-channel scale is calculated as in maxabs_pow2. The scale of activations is calculated as in maxabs_hw.

measure_exclude: Tensors to exclude from measurement. If this attribute is not defined, the default is OUTPUT. Since most models do not require measuring output tensors, excluding them speeds up the measurement process.
- NONE - All tensors are measured.
- OUTPUT (default) - Measurement of output tensors is excluded.
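
These parameters are typically supplied through a JSON file whose path is passed in the QUANT_CONFIG environment variable (used in the optimum-habana and vLLM examples below). A minimal quantize-mode config might look like the following sketch, which uses only the attributes documented above with illustrative values:

{
    "mode": "QUANTIZE",
    "observer": "maxabs",
    "scale_method": "maxabs_hw",
    "allowlist": {"types": [], "names": []},
    "blocklist": {"types": [], "names": []},
    "dump_stats_path": "./hqt_output/measure"
}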

Get Started with FP8 Quantization

Demo Usage
Computer vision example
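
For orientation, a minimal sketch of the prepare/calibrate/convert flow used in those examples (assuming a Gaudi/HPU environment; the toy model and random calibration data are placeholders, and the linked examples show the complete versions):

import torch
import habana_frameworks.torch.core as htcore  # registers the "hpu" device
from neural_compressor.torch.quantization import FP8Config, prepare, convert

model = torch.nn.Sequential(torch.nn.Linear(16, 32), torch.nn.ReLU(), torch.nn.Linear(32, 8)).to("hpu")
config = FP8Config(fp8_config="E4M3")          # see Supported Parameters above

model = prepare(model, config)                 # patch supported modules and attach observers
with torch.no_grad():                          # calibration: run representative data through the model
    for _ in range(8):
        model(torch.randn(4, 16).to("hpu"))

model = convert(model)                         # quantize to FP8 using the measured statistics
out = model(torch.randn(4, 16).to("hpu"))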

Optimum-habana LLM example

Overview

Optimum is an extension of Transformers that provides a set of performance optimization tools to train and run models on targeted hardware with maximum efficiency.
Optimum-habana is the interface between the Transformers and Diffusers libraries and Intel Gaudi AI Accelerators (HPU). It delivers higher performance through modified modeling files and uses Intel Neural Compressor internally for FP8 quantization; see running-with-fp8.

Installation

Refer to optimum-habana, install-the-library-and-get-example-scripts
Optionally, install from source:

$ git clone https://github.com/huggingface/optimum-habana
$ cd optimum-habana && git checkout v1.14.0  # change to the desired version
$ pip install -e .
$ pip install git+https://github.com/HabanaAI/[email protected]
$ cd examples/text-generation
$ pip install -r requirements.txt
$ pip install -r requirements_lm_eval.txt  # optional

Check neural_compressor code

optimum-habana/examples/text-generation/utils.py

initialize_model() -> setup_model() -> setup_quantization() -> FP8Config/prepare()/convert()
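
A simplified sketch of what that call chain boils down to (the real utils.py reads the JSON path from the QUANT_CONFIG environment variable and handles more cases):

from neural_compressor.torch.quantization import FP8Config, convert, prepare

def setup_quantization(model, quant_config_path):
    # quant_config_path: a JSON config such as maxabs_measure.json or maxabs_quant.json
    config = FP8Config.from_json_file(quant_config_path)
    if config.measure:                 # MEASURE mode: attach observers, dump statistics
        model = prepare(model, config)
    elif config.quantize:              # QUANTIZE mode: load measurements and convert to FP8
        model = convert(model, config)
    return model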

FP8 KV cache

Introduction: kv-cache-quantization in huggingface transformers

BF16 KV cache code -> modeling_all_models.py -> KVCache()

FP8 KV cache code trace with Intel Neural Compressor support, using Llama models as an example:

optimum-habana/optimum/habana/transformers/models/llama/modeling_llama.py

GaudiLlamaForCausalLM() -> self.model()

GaudiLlamaModel() -> forward() -> decoder_layer() -> GaudiLlamaDecoderLayer() forward() -> pre_attn() -> pre_attn_forward() -> self.k_cache.update

neural_compressor/torch/algorithms/fp8_quant/_quant_common/helper_modules.py

PatchedKVCache() -> update()
PatchedModuleFusedSDPA()
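
Conceptually, the patched cache quantizes key/value tensors to FP8 when they are written and dequantizes them when they are read back for attention. The following is a rough illustration of the idea only, not the actual helper_modules.py implementation: it uses plain PyTorch FP8 casts instead of the HPU kernels, and a hypothetical per-tensor scale that would come from the measurement step in a real run.

import torch

class FP8KVCacheSketch:
    """Illustration only: store KV states in FP8 with a per-tensor scale."""
    def __init__(self, scale: float):
        self.scale = scale       # per-tensor scale, e.g. derived from a maxabs measurement
        self.chunks = []         # FP8 storage, one chunk per update

    def update(self, new_kv: torch.Tensor) -> torch.Tensor:
        # quantize on write: scale into FP8 range, then cast to an 8-bit float dtype
        self.chunks.append((new_kv / self.scale).to(torch.float8_e4m3fn))
        return new_kv

    def fetch(self) -> torch.Tensor:
        # dequantize on read: cast back to bf16 and undo the scaling before attention
        # (assumes the sequence dimension is -2)
        return torch.cat([c.to(torch.bfloat16) * self.scale for c in self.chunks], dim=-2)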

Models that support the FP8 KV cache:

microsoft/Phi-3-mini-4k-instruct
bigcode/starcoder2-3b
Qwen/Qwen2.5-7B-Instruct
meta-llama/Llama-3.2-3B-Instruct
tiiuae/falcon-7b-instruct
mistralai/Mixtral-8x7B-Instruct-v0.1
EleutherAI/gpt-j-6b
mistralai/Mistral-Nemo-Instruct-2407
...

Running with FP8

Refer to here.
Change "--model_name_or_path" to be your model like
"meta-llama/Llama-3.1-8B-Instruct",
"Qwen/Qwen2.5-7B-Instruct", or
"mistralai/Mixtral-8x7B-Instruct-v0.1" and so on.
"--use_kv_cache" is to enable FP8 KV cache.

Profiling

Add "--profiling_warmup_steps 5 --profiling_steps 2 --profiling_record_shapes" as args in the end of commandline of run_generation.py.
Refer to torch.profiler.ProfilerActivity.HPU.
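
These flags build on torch.profiler; a minimal sketch of using the profiler directly (assumes a Gaudi PyTorch build that registers the HPU activity, per the reference above; generate_step() is a placeholder):

from torch.profiler import ProfilerActivity, profile

def generate_step():
    ...  # placeholder: one generation step of your model on the HPU

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.HPU],
             record_shapes=True) as prof:
    generate_step()

print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))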

FP8 Accuracy

"lm_eval.tasks", "lm_eval.evaluator", "lm_eval" are installed from the above requirements_lm_eval.txt. The tasks can be set and the default is ["hellaswag", "lambada_openai", "piqa", "winogrande"], more info

| Llama-2-7b-hf | fp8 w/ fp8 KVCache | bf16 w/ bf16 KVCache |
|---|---|---|
| hellaswag | 0.5691097390957977 | 0.5704043019318861 |
| lambada_openai | 0.7360760721909567 | 0.7372404424607025 |
| piqa | 0.7850924918389554 | 0.7818280739934712 |
| winogrande | 0.6929755327545383 | 0.6929755327545383 |

| Qwen2.5-7B-Instruct | fp8 w/ fp8 KVCache | bf16 w/ bf16 KVCache |
|---|---|---|
| hellaswag | 0.2539334793865764 | 0.2539334793865764 |
| lambada_openai | 0.0 | 0.0 |
| piqa | 0.5391730141458106 | 0.5391730141458106 |
| winogrande | 0.4956590370955012 | 0.4956590370955012 |

| Llama-3.1-8B-Instruct | fp8 w/ fp8 KVCache | bf16 w/ bf16 KVCache |
|---|---|---|
| hellaswag | 0.5934076877116112 | 0.5975901214897431 |
| lambada_openai | 0.7230739375121289 | 0.7255967397632447 |
| piqa | 0.7932535364526659 | 0.8030467899891186 |
| winogrande | 0.7434885556432518 | 0.7371744277821626 |

| Mixtral-8x7B-Instruct-v0.1 | fp8 w/ fp8 KVCache | bf16 w/ bf16 KVCache |
|---|---|---|
| hellaswag | 0.25323640709022105 | 0.25323640709022105 |
| lambada_openai | 0.0 | 0.0 |
| piqa | 0.528835690968444 | 0.528835690968444 |
| winogrande | 0.4956590370955012 | 0.4956590370955012 |

vLLM example

Overview

vLLM on Intel Gaudi (HPU) is provided through the HabanaAI vllm-fork, which integrates Intel Neural Compressor for FP8 quantization (enabled with --quantization inc).

Installation

Refer to Habana vllm-fork to install.
Optionally, install vllm-hpu-extension, neural_compressor, and vllm from source:

$ git clone https://github.com/HabanaAI/vllm-fork.git
$ cd vllm-fork
$ pip install -r requirements-hpu.txt
$ python setup.py develop --user

## Check
$ pip list |grep vllm
vllm                              0.6.3.dev1122+g2f43ebf5.d20241121.gaudi118 /home/fengding/vllm-fork
vllm-hpu-extension                0.1

## Validation
$ VLLM_SKIP_WARMUP=true python3 examples/offline_inference.py
......
Prompt: 'Hello, my name is', Generated text: ' Kelly and I have a job to do.\nI need someone to come over'
Prompt: 'The president of the United States is', Generated text: ' facing a sharp criticism of his handling of the coronavirus pandemic, including'
Prompt: 'The capital of France is', Generated text: ' the capital of the Socialist Party of France (SPF), with its state-'
Prompt: 'The future of AI is', Generated text: " in what's coming, not what's coming.\nI don't know what"

Run FP8 calibration

Refer to vllm-hpu-extension->calibration

$ git clone https://github.com/HabanaAI/vllm-hpu-extension
$ cd vllm-hpu-extension/calibration

# For Llama-3.1-8B-Instruct
$ ./calibrate_model.sh -m meta-llama/Llama-3.1-8B-Instruct -d /home/fengding/processed-data.pkl -o ./output_llama3.1.8b.Instruct -b 128 -t 1 -l 128
    ## Generate scale factors in ./output_llama3.1.8b.Instruct

Start vllm server

$ cd vllm-fork/

$ PT_HPU_ENABLE_LAZY_COLLECTIVES=true \
PT_HPU_WEIGHT_SHARING=0 \
VLLM_CONTIGUOUS_PA=true \
VLLM_SKIP_WARMUP=true \
QUANT_CONFIG=output_llama3.1.8b.Instruct/maxabs_quant_g2.json \
python3 -m vllm.entrypoints.openai.api_server \
--model meta-llama/Llama-3.1-8B-Instruct \
--port 8080 \
--gpu-memory-utilization 0.9 \
--tensor-parallel-size 1 \
--disable-log-requests \
--block-size 128 \
--quantization inc \
--kv-cache-dtype fp8_inc \
--device hpu \
--weights-load-device cpu \
--dtype bfloat16 \
--num_scheduler_steps 16 2>&1 > vllm_serving.log &

Refer to vllm-fork->README_GAUDI.md for more details.

Start client to test

$ curl --noproxy "*" http://localhost:8080/v1/completions -H "Content-Type: application/json" -d '{"model": "meta-llama/Llama-3.1-8B-Instruct", "prompt": "San Francisco is a", "max_tokens": 100}'
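
The same request can also be sent with the OpenAI Python client pointed at the local server (a sketch; requires pip install openai, and any non-empty api_key string works for a local vLLM server):

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8080/v1", api_key="EMPTY")
completion = client.completions.create(
    model="meta-llama/Llama-3.1-8B-Instruct",
    prompt="San Francisco is a",
    max_tokens=100,
)
print(completion.choices[0].text)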

Run benchmark

python benchmarks/benchmark_serving.py \
--backend vllm \
--model meta-llama/Llama-3.1-8B-Instruct  \
--dataset-name sonnet \
--dataset-path benchmarks/sonnet.txt \
--request-rate 128 \
--num-prompts 128 \
--port 8080 \
--sonnet-input-len 128 \
--sonnet-output-len 128 \
--sonnet-prefix-len 100

FP8 KV cache

Code trace

vllm-fork/vllm/attention/backends/hpu_attn.py

from vllm_hpu_extension.utils import Matmul, Softmax, VLLMKVCache

HPUAttentionImpl() -> self.k_cache() / self.v_cache()

neural_compressor/torch/algorithms/fp8_quant/_quant_common/helper_modules.py

PatchedVLLMKVCache()

neural_compressor/torch/algorithms/fp8_quant/common.py

"VLLMKVCache": ModuleInfo("kv_cache", PatchedVLLMKVCache)