diff --git a/src/python/py/models/builder.py b/src/python/py/models/builder.py index 8e5dc006c..045216ecf 100644 --- a/src/python/py/models/builder.py +++ b/src/python/py/models/builder.py @@ -163,7 +163,11 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options): # MatMul-specific variables is_lora = hasattr(config, "peft_type") and config.peft_type == "LORA" self.matmul_attrs = { - "use_lora": is_lora, # Use LoRA/QLoRA format + "use_lora": is_lora, # Use LoRA/QLoRA format + "lora": { # used to calculate scaling factors for LoRA/QLoRA + "alpha": config.lora_alpha if is_lora else 0, + "r": config.r if is_lora else 0 + } } # RotaryEmbedding-specific variables @@ -437,7 +441,7 @@ def save_model(self, out_dir): # Quantize ONNX model to desired precision # TODO: Replace by quantizing the MatMuls as they are created already_quantized_in_qdq_format = self.quant_type is not None and self.quant_attrs["use_qdq"] # Skip quantizing `MatMul` in `DequantizeLinear --> Transpose --> MatMul` path - if self.onnx_dtype == "int4" and not already_quantized_in_qdq_format: + if self.onnx_dtype == "int4" and not already_quantized_in_qdq_format and not self.matmul_attrs["use_lora"]: model = self.to_int4(model) # Save ONNX model with only one external data file and delete any existing duplicate copies @@ -714,7 +718,7 @@ def make_tanh(self, name, root_input, dtype, shape): self.make_value_info(output, dtype, shape=shape) def make_matmul(self, matmul, basename, root_input, **kwargs): - if hasattr(matmul, "base_layer"): + if hasattr(matmul, "lora_A"): # For LoRA `MatMul` return self.make_matmul_lora(matmul, basename, root_input, **kwargs) else: @@ -853,14 +857,23 @@ def make_matmul_lora(self, matmul, basename, root_input, **kwargs): matmul_A_name = self.make_matmul_op(matmul.lora_A.default, matmul_A_basename, root_input=root_input) lora_A = f"{matmul_A_name}/output_0" - matmul.lora_B.default.weight *= matmul.scaling["default"] + matmul.lora_B.default.weight *= (self.matmul_attrs["lora"]["alpha"] / self.matmul_attrs["lora"]["r"]) matmul_B_basename = "/".join(basename_parts[:-1] + ["lora_B"] + basename_parts[-1:]) matmul_B_name = self.make_matmul_op(matmul.lora_B.default, matmul_B_basename, root_input=lora_A) lora_B = f"{matmul_B_name}/output_0" - # Make regular MatMul path - last_dim = matmul.base_layer.weight.shape[0] - matmul_name = self.make_matmul_op(matmul.base_layer, basename, root_input, **kwargs) + if hasattr(matmul, "base_layer"): + # Make MatMul with base_layer + last_dim = matmul.base_layer.weight.shape[0] + matmul_name = self.make_matmul_op(matmul.base_layer, basename, root_input, **kwargs) + elif hasattr(matmul, "qweight"): + # Make quantized MatMul path + last_dim = matmul.qweight.shape[0] + matmul_name = self.make_matmul_op(matmul, basename, root_input, **kwargs) + else: + # Make regular MatMul path + last_dim = matmul.weight.shape[0] + matmul_name = self.make_matmul_op(matmul, basename, root_input, **kwargs) # Make LoRA Add node add_name = "/".join(basename_parts[:-1] + ["lora", "Add"]) @@ -2026,23 +2039,14 @@ def make_model(self, input_path): from onnxruntime_genai.models.quantized_model import QuantModel q_size = self.num_attn_heads * self.head_size kv_size = self.num_kv_heads * self.head_size - model = QuantModel.from_pretrained( - self.quant_type, - input_path, - self.quant_attrs["bits"], - self.quant_attrs["group_size"], - self.quant_attrs["use_g_idx"], - q_size, - kv_size, - self.intermediate_size, - self.num_layers, - ) + model = 
QuantModel.from_pretrained(self.quant_type, input_path, self.quant_attrs["bits"], self.quant_attrs["group_size"], self.quant_attrs["use_g_idx"], q_size, kv_size, self.intermediate_size, self.num_layers, self.extra_options.get("adapter_path", None)) else: # Load PyTorch model extra_kwargs = {"num_hidden_layers": self.num_layers} if "num_hidden_layers" in self.extra_options else {} model = AutoModelForCausalLM.from_pretrained(self.model_name_or_path, cache_dir=self.cache_dir, token=self.hf_token, trust_remote_code=True, **extra_kwargs) - if "adapter_path" in self.extra_options: + # Checking for adapter path in extra_options when the base_model is not quantized + if "adapter_path" in self.extra_options and self.quant_type is None: from peft import PeftModel model = PeftModel.from_pretrained(model, self.extra_options["adapter_path"], cache_dir=self.cache_dir, token=self.hf_token) diff --git a/src/python/py/models/quantized_model.py b/src/python/py/models/quantized_model.py index ed0a5beb6..dffacfab4 100644 --- a/src/python/py/models/quantized_model.py +++ b/src/python/py/models/quantized_model.py @@ -52,39 +52,50 @@ def __init__(self): self.weight = None self.bias = None + @property + def default(self): + return self -class QuantizedAttention: + +class QuantizedLoRAModule(QuantizedTensorModule): def __init__(self, bits, group_size): - self.q_proj = QuantizedTensorModule(bits, group_size) - self.k_proj = QuantizedTensorModule(bits, group_size) - self.v_proj = QuantizedTensorModule(bits, group_size) - self.o_proj = QuantizedTensorModule(bits, group_size) + super().__init__(bits, group_size) + self.lora_A = TensorModule() + self.lora_B = TensorModule() + + +class QuantizedAttention: + def __init__(self, bits, group_size, is_lora): + self.q_proj = QuantizedTensorModule(bits, group_size) if not is_lora else QuantizedLoRAModule(bits, group_size) + self.k_proj = QuantizedTensorModule(bits, group_size) if not is_lora else QuantizedLoRAModule(bits, group_size) + self.v_proj = QuantizedTensorModule(bits, group_size) if not is_lora else QuantizedLoRAModule(bits, group_size) + self.o_proj = QuantizedTensorModule(bits, group_size) if not is_lora else QuantizedLoRAModule(bits, group_size) self.rotary_emb = TensorModule() class QuantizedMLP: - def __init__(self, bits, group_size): - self.gate_proj = QuantizedTensorModule(bits, group_size) - self.up_proj = QuantizedTensorModule(bits, group_size) - self.down_proj = QuantizedTensorModule(bits, group_size) - self.fc1 = QuantizedTensorModule(bits, group_size) - self.fc2 = QuantizedTensorModule(bits, group_size) + def __init__(self, bits, group_size, is_lora): + self.gate_proj = QuantizedTensorModule(bits, group_size) if not is_lora else QuantizedLoRAModule(bits, group_size) + self.up_proj = QuantizedTensorModule(bits, group_size) if not is_lora else QuantizedLoRAModule(bits, group_size) + self.down_proj = QuantizedTensorModule(bits, group_size) if not is_lora else QuantizedLoRAModule(bits, group_size) + self.fc1 = QuantizedTensorModule(bits, group_size) if not is_lora else QuantizedLoRAModule(bits, group_size) + self.fc2 = QuantizedTensorModule(bits, group_size) if not is_lora else QuantizedLoRAModule(bits, group_size) class QuantizedDecoderLayer: - def __init__(self, layer_id, bits, group_size): + def __init__(self, layer_id, bits, group_size, is_lora): self.layer_id = layer_id self.input_layernorm = TensorModule() - self.self_attn = QuantizedAttention(bits, group_size) + self.self_attn = QuantizedAttention(bits, group_size, is_lora) 
self.post_attention_layernorm = TensorModule() - self.mlp = QuantizedMLP(bits, group_size) + self.mlp = QuantizedMLP(bits, group_size, is_lora) def is_empty(self): return self.input_layernorm.weight is None class QuantizedModel: - def __init__(self, quant_type, input_path, bits, group_size, q_size, kv_size, intermediate_size, num_layers): + def __init__(self, quant_type, input_path, bits, group_size, q_size, kv_size, intermediate_size, num_layers, adapter_path): self.quant_type = quant_type self.embedding = TensorModule() self.final_norm = TensorModule() @@ -92,11 +103,33 @@ def __init__(self, quant_type, input_path, bits, group_size, q_size, kv_size, in self.layers = {} self.num_layers = num_layers + self.is_lora = adapter_path is not None + + self.map_to_modules(quant_type, input_path, bits, group_size, q_size, kv_size, intermediate_size) + + if self.is_lora: + self.map_to_modules(quant_type, adapter_path, bits, group_size, q_size, kv_size, intermediate_size) + + # Set LM head weights + biases if not already set + if isinstance(self.lm_head, TensorModule) and self.lm_head.weight is None: + # Embedding and LM head share same weights + biases (lm_head.weight == embedding.weight and lm_head.bias == embedding.bias) + self.lm_head.weight = self.embedding.weight + if self.lm_head.bias is not None: + self.lm_head.bias = self.embedding.bias + + # Sort list of layers by layer id + self.layers = list(self.layers.values()) + self.layers.sort(key=lambda m: m.layer_id) + + # Set properties of each layer based on quantization type + self.set_properties() + + def map_to_modules(self, quant_type, file_path, bits, group_size, q_size, kv_size, intermediate_size): layer_id = 0 - for weight_file in os.listdir(input_path): + for weight_file in os.listdir(file_path): if weight_file.endswith(".safetensors"): - module = self.layers.setdefault(layer_id, QuantizedDecoderLayer(layer_id, bits, group_size)) - weights = load_file(os.path.join(input_path, weight_file)) + module = self.layers.setdefault(layer_id, QuantizedDecoderLayer(layer_id, bits, group_size, self.is_lora)) + weights = load_file(os.path.join(file_path, weight_file)) # Map weights to modules for name, tensor in weights.items(): @@ -133,11 +166,13 @@ def __init__(self, quant_type, input_path, bits, group_size, q_size, kv_size, in if name.startswith("transformer.encoder"): # Chatglm3, e.g., transformer.encoder.layers.0.input_layernorm.weight name = name.replace("transformer.encoder", "model") + if name.startswith("base_model.model.model"): + name = name.replace("base_model.model.model", "model") curr_layer_id = int(name.split(".")[2]) if curr_layer_id != layer_id: # Switch layer module used layer_id = curr_layer_id - module = self.layers.setdefault(layer_id, QuantizedDecoderLayer(layer_id, bits, group_size)) + module = self.layers.setdefault(layer_id, QuantizedDecoderLayer(layer_id, bits, group_size, self.is_lora)) # Map weights and biases of norm, attention, and feed-forward network # Graph order is input_layernorm --> q_proj/k_proj/v_proj --> o_proj --> post_attention_layernorm --> gate_proj/up_proj --> down_proj @@ -334,23 +369,51 @@ def __init__(self, quant_type, input_path, bits, group_size, q_size, kv_size, in # model.layers.layer_id.mlp.dense_h_to_4h.bias module.mlp.gate_proj.bias = tensor[: intermediate_size] module.mlp.down_proj.bias = tensor[intermediate_size: ] + elif bool(re.match(r"^model.layers\.\d+\.self_attn.q_proj.lora_A\.weight$", name)): + # model.layers.layer_id.self_attn.q_proj.lora_A.weight + module.self_attn.q_proj.lora_A.weight 
= tensor + elif bool(re.match(r"^model.layers\.\d+\.self_attn.q_proj.lora_B\.weight$", name)): + # model.layers.layer_id.self_attn.q_proj.lora_B.weight + module.self_attn.q_proj.lora_B.weight = tensor + elif bool(re.match(r"^model.layers\.\d+\.self_attn.k_proj.lora_A\.weight$", name)): + # model.layers.layer_id.self_attn.k_proj.lora_A.weight + module.self_attn.k_proj.lora_A.weight = tensor + elif bool(re.match(r"^model.layers\.\d+\.self_attn.k_proj.lora_B\.weight$", name)): + # model.layers.layer_id.self_attn.k_proj.lora_B.weight + module.self_attn.k_proj.lora_B.weight = tensor + elif bool(re.match(r"^model.layers\.\d+\.self_attn.v_proj.lora_A\.weight$", name)): + # model.layers.layer_id.self_attn.v_proj.lora_A.weight + module.self_attn.v_proj.lora_A.weight = tensor + elif bool(re.match(r"^model.layers\.\d+\.self_attn.v_proj.lora_B\.weight$", name)): + # model.layers.layer_id.self_attn.v_proj.lora_B.weight + module.self_attn.v_proj.lora_B.weight = tensor + elif bool(re.match(r"^model.layers\.\d+\.self_attn.o_proj.lora_A\.weight$", name)): + # model.layers.layer_id.self_attn.o_proj.lora_A.weight + module.self_attn.o_proj.lora_A.weight = tensor + elif bool(re.match(r"^model.layers\.\d+\.self_attn.o_proj.lora_B\.weight$", name)): + # model.layers.layer_id.self_attn.o_proj.lora_B.weight + module.self_attn.o_proj.lora_B.weight = tensor + elif bool(re.match(r"^model.layers\.\d+\.mlp.gate_proj.lora_A\.weight$", name)): + # model.layers.layer_id.mlp.gate_proj.lora_A.weight + module.mlp.gate_proj.lora_A.weight = tensor + elif bool(re.match(r"^model.layers\.\d+\.mlp.gate_proj.lora_B\.weight$", name)): + # model.layers.layer_id.mlp.gate_proj.lora_B.weight + module.mlp.gate_proj.lora_B.weight = tensor + elif bool(re.match(r"^model.layers\.\d+\.mlp.up_proj.lora_A\.weight$", name)): + # model.layers.layer_id.mlp.up_proj.lora_A.weight + module.mlp.up_proj.lora_A.weight = tensor + elif bool(re.match(r"^model.layers\.\d+\.mlp.up_proj.lora_B\.weight$", name)): + # model.layers.layer_id.mlp.up_proj.lora_B.weight + module.mlp.up_proj.lora_B.weight = tensor + elif bool(re.match(r"^model.layers\.\d+\.mlp.down_proj.lora_A\.weight$", name)): + # model.layers.layer_id.mlp.down_proj.lora_A.weight + module.mlp.down_proj.lora_A.weight = tensor + elif bool(re.match(r"^model.layers\.\d+\.mlp.down_proj.lora_B\.weight$", name)): + # model.layers.layer_id.mlp.down_proj.lora_B.weight + module.mlp.down_proj.lora_B.weight = tensor else: raise NotImplementedError(f"{name} in your quantized model is not recognized.") - # Set LM head weights + biases if not already set - if isinstance(self.lm_head, TensorModule) and self.lm_head.weight is None: - # Embedding and LM head share same weights + biases (lm_head.weight == embedding.weight and lm_head.bias == embedding.bias) - self.lm_head.weight = self.embedding.weight - if self.lm_head.bias is not None: - self.lm_head.bias = self.embedding.bias - - # Sort list of layers by layer id - self.layers = list(self.layers.values()) - self.layers.sort(key=lambda m: m.layer_id) - - # Set properties of each layer based on quantization type - self.set_properties() - def _initialize_quantized_lm_head(self, bits, group_size): """ Initialize `QuantizedTensorModule` for LM head if not already set @@ -376,6 +439,7 @@ def set_properties(self): self.lm_head.in_features = self.lm_head.g_idx.shape[0] else: raise NotImplementedError(f"The {self.quant_type} quantization method is not recognized.") + for module in self.layers: if self.quant_type == "awq": # Set in_features and out_features @@ -393,6 
+457,36 @@ def set_properties(self): module.mlp.up_proj.in_features = module.mlp.up_proj.qweight.shape[0] module.mlp.down_proj.out_features = module.mlp.down_proj.scales.shape[1] module.mlp.down_proj.in_features = module.mlp.down_proj.qweight.shape[0] + if self.is_lora: + module.self_attn.q_proj.lora_A.out_features = module.self_attn.q_proj.lora_A.weight.shape[1] + module.self_attn.q_proj.lora_A.in_features = module.self_attn.q_proj.lora_A.weight.shape[0] + module.self_attn.q_proj.lora_B.out_features = module.self_attn.q_proj.lora_B.weight.shape[1] + module.self_attn.q_proj.lora_B.in_features = module.self_attn.q_proj.lora_B.weight.shape[0] + module.self_attn.k_proj.lora_A.out_features = module.self_attn.k_proj.lora_A.weight.shape[1] + module.self_attn.k_proj.lora_A.in_features = module.self_attn.k_proj.lora_A.weight.shape[0] + module.self_attn.k_proj.lora_B.out_features = module.self_attn.k_proj.lora_B.weight.shape[1] + module.self_attn.k_proj.lora_B.in_features = module.self_attn.k_proj.lora_B.weight.shape[0] + module.self_attn.v_proj.lora_A.out_features = module.self_attn.v_proj.lora_A.weight.shape[1] + module.self_attn.v_proj.lora_A.in_features = module.self_attn.v_proj.lora_A.weight.shape[0] + module.self_attn.v_proj.lora_B.out_features = module.self_attn.v_proj.lora_B.weight.shape[1] + module.self_attn.v_proj.lora_B.in_features = module.self_attn.v_proj.lora_B.weight.shape[0] + module.self_attn.o_proj.lora_A.out_features = module.self_attn.o_proj.lora_A.weight.shape[1] + module.self_attn.o_proj.lora_A.in_features = module.self_attn.o_proj.lora_A.weight.shape[0] + module.self_attn.o_proj.lora_B.out_features = module.self_attn.o_proj.lora_B.weight.shape[1] + module.self_attn.o_proj.lora_B.in_features = module.self_attn.o_proj.lora_B.weight.shape[0] + module.mlp.gate_proj.lora_A.out_features = module.mlp.gate_proj.lora_A.weight.shape[1] + module.mlp.gate_proj.lora_A.in_features = module.mlp.gate_proj.lora_A.weight.shape[0] + module.mlp.gate_proj.lora_B.out_features = module.mlp.gate_proj.lora_B.weight.shape[1] + module.mlp.gate_proj.lora_B.in_features = module.mlp.gate_proj.lora_B.weight.shape[0] + module.mlp.up_proj.lora_A.out_features = module.mlp.up_proj.lora_A.weight.shape[1] + module.mlp.up_proj.lora_A.in_features = module.mlp.up_proj.lora_A.weight.shape[0] + module.mlp.up_proj.lora_B.out_features = module.mlp.up_proj.lora_B.weight.shape[1] + module.mlp.up_proj.lora_B.in_features = module.mlp.up_proj.lora_B.weight.shape[0] + module.mlp.down_proj.lora_A.out_features = module.mlp.down_proj.lora_A.weight.shape[1] + module.mlp.down_proj.lora_A.in_features = module.mlp.down_proj.lora_A.weight.shape[0] + module.mlp.down_proj.lora_B.out_features = module.mlp.down_proj.lora_B.weight.shape[1] + module.mlp.down_proj.lora_B.in_features = module.mlp.down_proj.lora_B.weight.shape[0] + # Set g_idx if not already set module.self_attn.q_proj.g_idx = module.self_attn.q_proj.g_idx if module.self_attn.q_proj.g_idx is not None else torch.tensor([i // module.self_attn.q_proj.group_size for i in range(module.self_attn.q_proj.in_features)], dtype=torch.int32) @@ -419,6 +513,35 @@ def set_properties(self): module.mlp.up_proj.in_features = module.mlp.up_proj.g_idx.shape[0] module.mlp.down_proj.out_features = module.mlp.down_proj.qweight.shape[1] module.mlp.down_proj.in_features = module.mlp.down_proj.g_idx.shape[0] + if self.is_lora: + module.self_attn.q_proj.lora_A.out_features = module.self_attn.q_proj.lora_A.weight.shape[1] + module.self_attn.q_proj.lora_A.in_features = 
module.self_attn.q_proj.lora_A.weight.shape[0] + module.self_attn.q_proj.lora_B.out_features = module.self_attn.q_proj.lora_B.weight.shape[1] + module.self_attn.q_proj.lora_B.in_features = module.self_attn.q_proj.lora_B.weight.shape[0] + module.self_attn.k_proj.lora_A.out_features = module.self_attn.k_proj.lora_A.weight.shape[1] + module.self_attn.k_proj.lora_A.in_features = module.self_attn.k_proj.lora_A.weight.shape[0] + module.self_attn.k_proj.lora_B.out_features = module.self_attn.k_proj.lora_B.weight.shape[1] + module.self_attn.k_proj.lora_B.in_features = module.self_attn.k_proj.lora_B.weight.shape[0] + module.self_attn.v_proj.lora_A.out_features = module.self_attn.v_proj.lora_A.weight.shape[1] + module.self_attn.v_proj.lora_A.in_features = module.self_attn.v_proj.lora_A.weight.shape[0] + module.self_attn.v_proj.lora_B.out_features = module.self_attn.v_proj.lora_B.weight.shape[1] + module.self_attn.v_proj.lora_B.in_features = module.self_attn.v_proj.lora_B.weight.shape[0] + module.self_attn.o_proj.lora_A.out_features = module.self_attn.o_proj.lora_A.weight.shape[1] + module.self_attn.o_proj.lora_A.in_features = module.self_attn.o_proj.lora_A.weight.shape[0] + module.self_attn.o_proj.lora_B.out_features = module.self_attn.o_proj.lora_B.weight.shape[1] + module.self_attn.o_proj.lora_B.in_features = module.self_attn.o_proj.lora_B.weight.shape[0] + module.mlp.gate_proj.lora_A.out_features = module.mlp.gate_proj.lora_A.weight.shape[1] + module.mlp.gate_proj.lora_A.in_features = module.mlp.gate_proj.lora_A.weight.shape[0] + module.mlp.gate_proj.lora_B.out_features = module.mlp.gate_proj.lora_B.weight.shape[1] + module.mlp.gate_proj.lora_B.in_features = module.mlp.gate_proj.lora_B.weight.shape[0] + module.mlp.up_proj.lora_A.out_features = module.mlp.up_proj.lora_A.weight.shape[1] + module.mlp.up_proj.lora_A.in_features = module.mlp.up_proj.lora_A.weight.shape[0] + module.mlp.up_proj.lora_B.out_features = module.mlp.up_proj.lora_B.weight.shape[1] + module.mlp.up_proj.lora_B.in_features = module.mlp.up_proj.lora_B.weight.shape[0] + module.mlp.down_proj.lora_A.out_features = module.mlp.down_proj.lora_A.weight.shape[1] + module.mlp.down_proj.lora_A.in_features = module.mlp.down_proj.lora_A.weight.shape[0] + module.mlp.down_proj.lora_B.out_features = module.mlp.down_proj.lora_B.weight.shape[1] + module.mlp.down_proj.lora_B.in_features = module.mlp.down_proj.lora_B.weight.shape[0] else: raise NotImplementedError(f"The {self.quant_type} quantization method is not recognized.") @@ -587,8 +710,8 @@ def pack_ort_format(self, module, intweight): class AWQModel(QuantizedModel): - def __init__(self, quant_type, input_path, bits, group_size, q_size, kv_size, intermediate_size, num_layers): - super().__init__(quant_type, input_path, bits, group_size, q_size, kv_size, intermediate_size, num_layers) + def __init__(self, quant_type, input_path, bits, group_size, q_size, kv_size, intermediate_size, num_layers, adapter_path): + super().__init__(quant_type, input_path, bits, group_size, q_size, kv_size, intermediate_size, num_layers, adapter_path) # Unpack and repack all `QuantizedTensorModule` classes in model for i, layer in enumerate(self.layers): @@ -662,8 +785,8 @@ def reverse_reorder_tensor(self, tensor, bits): class GPTQModel(QuantizedModel): - def __init__(self, quant_type, input_path, bits, group_size, use_g_idx, q_size, kv_size, intermediate_size, num_layers): - super().__init__(quant_type, input_path, bits, group_size, q_size, kv_size, intermediate_size, num_layers) + def __init__(self, quant_type, 
input_path, bits, group_size, use_g_idx, q_size, kv_size, intermediate_size, num_layers, adapter_path): + super().__init__(quant_type, input_path, bits, group_size, q_size, kv_size, intermediate_size, num_layers, adapter_path) # Unpack and repack all `QuantizedTensorModule` classes in model for i, layer in enumerate(self.layers): @@ -729,16 +852,16 @@ def __init__(self, module): class QuantModel: @staticmethod - def from_pretrained(quant_type, input_path, bits, group_size, use_g_idx, q_size, kv_size, intermediate_size, num_layers): + def from_pretrained(quant_type, input_path, bits, group_size, use_g_idx, q_size, kv_size, intermediate_size, num_layers, adapter_path): """ Unpack quantized weights in PyTorch models, store them in a standard format, and repack them into ONNX Runtime's format. Also performs any pre-processing and post-processing when unpacking the quantized weights. """ if quant_type == "awq": - model = AWQModel(quant_type, input_path, bits, group_size, q_size, kv_size, intermediate_size, num_layers) + model = AWQModel(quant_type, input_path, bits, group_size, q_size, kv_size, intermediate_size, num_layers, adapter_path) elif quant_type == "gptq": - model = GPTQModel(quant_type, input_path, bits, group_size, use_g_idx, q_size, kv_size, intermediate_size, num_layers) + model = GPTQModel(quant_type, input_path, bits, group_size, use_g_idx, q_size, kv_size, intermediate_size, num_layers, adapter_path) else: raise NotImplementedError(f"The {quant_type} quantized model is not currently supported.")
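For context on the builder.py change that replaces `matmul.scaling["default"]` with `lora_alpha / r`: folding the LoRA scaling into `lora_B`'s weight keeps the exported graph to two MatMuls plus an Add, with no extra scaling op, while matching what PEFT computes at runtime. Below is a minimal numeric sketch of that equivalence using made-up sizes and plain PyTorch; it is not the builder's ONNX construction code.

```python
# Minimal numeric sketch (hypothetical sizes, plain PyTorch, not the builder's ONNX code):
# pre-multiplying lora_B.weight by lora_alpha / r reproduces PEFT's
# base(x) + scaling * lora_B(lora_A(x)) with an ordinary Add and no extra scaling node.
import torch

torch.manual_seed(0)
hidden, r, lora_alpha = 16, 4, 8
scaling = lora_alpha / r                      # what make_matmul_lora now computes

x = torch.randn(2, hidden)
base = torch.nn.Linear(hidden, hidden, bias=False)
lora_A = torch.nn.Linear(hidden, r, bias=False)
lora_B = torch.nn.Linear(r, hidden, bias=False)

reference = base(x) + scaling * lora_B(lora_A(x))   # PEFT-style forward

with torch.no_grad():
    lora_B.weight *= scaling                  # scaling folded into the B weights

folded = base(x) + lora_B(lora_A(x))          # graph shape: MatMul -> MatMul -> Add
assert torch.allclose(reference, folded, atol=1e-6)
```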
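The dispatch in `make_matmul` now keys on `lora_A` rather than `base_layer`, because a quantized LoRA module exposes `qweight` and `lora_A`/`lora_B` but has no `base_layer`; `make_matmul_lora` then picks the base projection from whichever of `base_layer`, `qweight`, or `weight` is present. A toy duck-typing sketch of that routing follows; the class names are illustrative stand-ins, not the builder's real module types.

```python
# Toy duck-typing sketch of the updated routing in make_matmul / make_matmul_lora.
class PlainLinear:
    weight = "W"

class PeftLoraLinear:                  # PEFT-wrapped float layer
    base_layer = PlainLinear()
    lora_A = "A"
    lora_B = "B"

class QuantLoraLinear:                 # what QuantizedLoRAModule provides: qweight + lora_A/B
    qweight = "QW"
    lora_A = "A"
    lora_B = "B"

def route(matmul):
    if hasattr(matmul, "lora_A"):      # LoRA path for both float and quantized bases
        if hasattr(matmul, "base_layer"):
            return "LoRA Add around base_layer MatMul"
        if hasattr(matmul, "qweight"):
            return "LoRA Add around quantized MatMul"
        return "LoRA Add around regular MatMul"
    return "non-LoRA MatMul"

for m in (PlainLinear(), PeftLoraLinear(), QuantLoraLinear()):
    print(type(m).__name__, "->", route(m))
```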
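In quantized_model.py, `TensorModule` gains a `default` property that returns `self`. The builder reads adapters as `matmul.lora_A.default` / `matmul.lora_B.default` because PEFT keeps them in a dict keyed by adapter name ("default" by default), and the property lets the flat modules loaded from safetensors satisfy the same attribute path. A stripped-down sketch:

```python
# Stripped-down sketch of the `default` property added to TensorModule: it lets the
# builder's PEFT-style access pattern (lora_A.default.weight) also resolve against the
# flat tensor modules that QuantizedModel populates from safetensors.
class TensorModule:
    def __init__(self):
        self.weight = None
        self.bias = None

    @property
    def default(self):
        # PEFT nests adapters one level deep under the adapter name ("default");
        # returning self mimics that indirection for already-flat modules.
        return self

lora_A = TensorModule()
lora_A.weight = "adapter A tensor"     # placeholder for the loaded tensor
assert lora_A.default.weight is lora_A.weight
```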
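The second `map_to_modules` pass over `adapter_path` sees PEFT-saved key names, which carry a `base_model.model.` prefix; the new `startswith("base_model.model.model")` rewrite plus the `lora_A`/`lora_B` regex branches route them into the same per-layer modules as the base weights. A small sketch of that normalization, using a hypothetical adapter key:

```python
# Sketch of the adapter-key normalization and matching done on the second
# map_to_modules pass. The key below is a hypothetical example of what a PEFT
# adapter safetensors file contains; the regex mirrors one branch added in the diff.
import re

name = "base_model.model.model.layers.3.self_attn.q_proj.lora_A.weight"

if name.startswith("base_model.model.model"):
    name = name.replace("base_model.model.model", "model")
# -> "model.layers.3.self_attn.q_proj.lora_A.weight"

layer_id = int(name.split(".")[2])                       # -> 3, selects the decoder layer
assert re.match(r"^model.layers\.\d+\.self_attn.q_proj.lora_A\.weight$", name)
# The matching elif then assigns: module.self_attn.q_proj.lora_A.weight = tensor
```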