Skip to content

Commit

Permalink
Merge pull request #159 from mistralai/add_lora
Browse files Browse the repository at this point in the history
Add LoRA
  • Loading branch information
pierrestock authored May 24, 2024
2 parents b2a519b + 788f67e commit cd06b0d
Show file tree
Hide file tree
Showing 7 changed files with 850 additions and 83 deletions.
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -233,3 +233,6 @@ Instructions to run the image can be found in the [official documentation](https
- Use Mistral models on [Mistral AI official API](https://console.mistral.ai/) (La Plateforme)
- Use Mistral models via [cloud providers](https://docs.mistral.ai/deployment/cloud/overview/)

## References

[1]: [LoRA](https://arxiv.org/abs/2106.09685): Low-Rank Adaptation of Large Language Models, Hu et al. 2021
717 changes: 646 additions & 71 deletions poetry.lock

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "mistral_inference"
version = "v1.0.4"
version = "v1.1.0"
description = ""
authors = ["bam4d <[email protected]>"]
readme = "README.md"
Expand All @@ -24,7 +24,7 @@ exclude = ["docs", "tools", "build"]

[tool.poetry.dependencies]
python = "^3.9.10"
xformers = ">=0.0.25"
xformers = ">=0.0.24"
simple-parsing = ">=0.1.5"
fire = ">=0.6.0"
mistral_common = "^1.0.0"
Expand Down
2 changes: 1 addition & 1 deletion src/mistral_inference/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "1.0.4"
__version__ = "1.1.0"
166 changes: 166 additions & 0 deletions src/mistral_inference/lora.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
import logging
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, NamedTuple, Union

import safetensors.torch
import torch
import torch.nn as nn
from simple_parsing.helpers import Serializable


@dataclass
class LoraArgs(Serializable):
    """Serializable LoRA hyper-parameters: the adapter rank and its scaling."""

    # Both fields are required and must be strictly positive; this is
    # checked right after construction in ``__post_init__``.
    rank: int
    scaling: float

    def __post_init__(self) -> None:
        """Reject non-positive ``rank`` or ``scaling`` with an AssertionError."""
        assert self.rank > 0 and self.scaling > 0.0


class LoRALinear(nn.Module):
"""
Implementation of:
- LoRA: https://arxiv.org/abs/2106.09685
Notes:
- Freezing is handled at network level, not layer level.
- Scaling factor controls relative importance of LoRA skip
connection versus original frozen weight. General guidance is
to keep it to 2.0 and sweep over learning rate when changing
the rank.
"""

def __init__(
self,
in_features: int,
out_features: int,
rank: int,
scaling: float,
bias: bool = False,
):
super().__init__()

self.in_features = in_features
self.out_features = out_features
assert not bias
self.bias = bias
self.rank = rank
self.scaling = scaling

self.lora_A = nn.Linear(
self.in_features,
self.rank,
bias=self.bias,
)
self.lora_B = nn.Linear(
self.rank,
self.out_features,
bias=self.bias,
)

self.linear = nn.Linear(self.in_features, self.out_features, bias=self.bias)

# make sure no LoRA weights are marked as "missing" in load_state_dict
def ignore_missing_keys(m: nn.Module, incompatible_keys: NamedTuple):
incompatible_keys.missing_keys[:] = [] # type: ignore

self.register_load_state_dict_post_hook(ignore_missing_keys)

def forward(self, x: torch.Tensor):
lora = self.lora_B(self.lora_A(x))
return self.linear(x) + lora * self.scaling

def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
key_name = prefix + "weight"

# full checkpoint
if key_name in state_dict:
w_ref = state_dict[key_name]

# load frozen weights
state_dict = {
"linear.weight": w_ref,
"lora_A.weight": torch.zeros_like(
self.lora_A.weight, device=w_ref.device, dtype=w_ref.dtype
),
"lora_B.weight": torch.zeros_like(
self.lora_B.weight, device=w_ref.device, dtype=w_ref.dtype
),
}
self.load_state_dict(state_dict, assign=True, strict=True)


class LoRALoaderMixin:
    """
    Mixin adding LoRA checkpoint loading to a model.

    The methods below read the following from the host object: ``dtype``,
    ``device``, ``args.lora``, ``layers``, ``pipeline_rank`` and the
    ``state_dict`` / ``named_modules`` / ``load_state_dict`` methods.
    """

    def load_lora(self, lora_path: Union[Path, str], scaling: float = 2.0) -> None:
        """
        Load a LoRA checkpoint from a safetensors file.

        Args:
            lora_path: Path to the safetensors file holding the LoRA weights.
            scaling: Factor applied to the low-rank product when the LoRA
                weights are merged into plain linear weights.

        Raises:
            AssertionError: If ``lora_path`` does not exist or is not a file.
        """
        lora_path = Path(lora_path)
        assert lora_path.is_file(), f"{lora_path} does not exist or is not a file"

        state_dict = safetensors.torch.load_file(lora_path)

        self._load_lora_state_dict(state_dict, scaling=scaling)

    def _load_lora_state_dict(
        self, lora_state_dict: Dict[str, torch.Tensor], scaling: float = 2.0
    ) -> None:
        """
        Load LoRA weights from an in-memory state dict.

        If ``self.args.lora`` is ``None`` the model has plain linear layers,
        so each LoRA pair is merged into the matching weight as
        ``weight + (lora_B @ lora_A) * scaling``. Otherwise the LoRA tensors
        are copied into the model's state dict under their own keys. In both
        cases entries whose layer id (second dot-separated name component)
        is not in ``self.layers`` are skipped, and the result is loaded with
        ``strict=True``.

        Args:
            lora_state_dict: Mapping from parameter name to LoRA tensor; every
                key must contain ``"lora"``.
            scaling: Factor applied to the low-rank product when merging.

        Raises:
            AssertionError: If the tensors do not all share one dtype, if that
                dtype differs from ``self.dtype``, or if a key lacks ``"lora"``.
        """
        lora_dtypes = {p.dtype for p in lora_state_dict.values()}
        assert (
            len(lora_dtypes) == 1
        ), f"LoRA weights have multipe different dtypes {lora_dtypes}. All weights need to have the same dtype"
        lora_dtype = lora_dtypes.pop()
        assert (
            lora_dtype == self.dtype
        ), f"LoRA weights dtype differs from model's dtype {lora_dtype} != {self.dtype}"
        assert all("lora" in key for key in lora_state_dict.keys())

        # move tensors to device
        lora_state_dict = {k: v.to(self.device) for k, v in lora_state_dict.items()}

        state_dict = self.state_dict()

        if self.args.lora is None:
            logging.info("Loading and merging LoRA weights...")

            # Fold the low-rank update into every nn.Linear weight except the
            # output layer.
            for name, module in self.named_modules():
                if not isinstance(module, nn.Linear) or name == "output":
                    continue

                layer_id = name.split(".")[1]
                if layer_id not in self.layers:
                    logging.debug(
                        "Skipping parameter %s at pipeline rank %d",
                        name,
                        self.pipeline_rank,
                    )
                    continue

                low_rank_update = (
                    lora_state_dict[name + ".lora_B.weight"]
                    @ lora_state_dict[name + ".lora_A.weight"]
                )
                state_dict[name + ".weight"] = module.weight + low_rank_update * scaling
        else:
            logging.info("Loading LoRA weights...")
            for k, v in lora_state_dict.items():
                # Copy only the entries whose layer lives on this pipeline
                # rank. A blanket ``state_dict.update(lora_state_dict)`` here
                # would also pull in the skipped layers' keys and undo the
                # filter.
                layer_id = k.split(".")[1]
                if layer_id in self.layers:
                    state_dict[k] = v
                else:
                    logging.debug(
                        "Skipping parameter %s at pipeline rank %d",
                        k,
                        self.pipeline_rank,
                    )

        self.load_state_dict(state_dict, strict=True)
12 changes: 11 additions & 1 deletion src/mistral_inference/main.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging
import os
from pathlib import Path
from typing import List
from typing import List, Optional

import fire # type: ignore
import torch
Expand Down Expand Up @@ -46,6 +46,7 @@ def interactive(
temperature: float = 0.7,
num_pipeline_ranks: int = 1,
instruct: bool = False,
lora_path: Optional[str] = None,
) -> None:
if is_torchrun():
torch.distributed.init_process_group()
Expand All @@ -64,6 +65,10 @@ def interactive(
Path(model_path), max_batch_size=3, num_pipeline_ranks=num_pipeline_ranks
)

# load LoRA
if lora_path is not None:
transformer.load_lora(Path(lora_path))

prompt: str = ""
messages: List[UserMessage | AssistantMessage] = []

Expand Down Expand Up @@ -117,6 +122,7 @@ def demo(
model_path: str,
max_tokens: int = 35,
temperature: float = 0,
lora_path: Optional[str] = None,
) -> None:
if is_torchrun():
torch.distributed.init_process_group()
Expand All @@ -131,6 +137,10 @@ def demo(
transformer = Transformer.from_folder(
Path(model_path), max_batch_size=3, num_pipeline_ranks=num_pipeline_ranks
)
# load LoRA
if lora_path is not None:
transformer.load_lora(Path(lora_path))

mistral_tokenizer: MistralTokenizer = load_tokenizer(Path(model_path))
tokenizer: Tokenizer = mistral_tokenizer.instruct_tokenizer.tokenizer

Expand Down
29 changes: 21 additions & 8 deletions src/mistral_inference/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import logging
import math
from dataclasses import dataclass
from functools import partial
from pathlib import Path
from typing import Any, List, Mapping, Optional, Tuple, Union

Expand All @@ -16,6 +17,7 @@
CacheInputMetadata,
CacheView,
)
from mistral_inference.lora import LoraArgs, LoRALinear, LoRALoaderMixin
from mistral_inference.moe import MoeArgs, MoeLayer
from mistral_inference.rope import apply_rotary_emb, precompute_freqs_cis

Expand All @@ -37,6 +39,8 @@ class ModelArgs(Serializable):
rope_theta: Optional[float] = None
# If this is set, we will use MoE layers instead of dense layers.
moe: Optional[MoeArgs] = None
# If this is set, we will load LoRA linear layers instead of linear layers.
lora: Optional[LoraArgs] = None


@dataclass
Expand All @@ -61,6 +65,13 @@ def repeat_kv(
return keys, values


def maybe_lora(args: ModelArgs) -> Union[type[nn.Linear], partial[LoRALinear]]:
    """
    Return the factory used to build linear layers for ``args``.

    The result is something to call with ``(in_features, out_features,
    bias=...)``, not a layer instance: the ``nn.Linear`` class itself when
    ``args.lora`` is ``None``, otherwise ``LoRALinear`` with ``rank`` and
    ``scaling`` pre-bound from ``args.lora``.
    """
    if args.lora is None:
        return nn.Linear
    return partial(LoRALinear, rank=args.lora.rank, scaling=args.lora.scaling)


class Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
Expand All @@ -74,10 +85,11 @@ def __init__(self, args: ModelArgs):

self.scale = self.args.head_dim**-0.5

self.wq = nn.Linear(args.dim, args.n_heads * args.head_dim, bias=False)
self.wk = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
self.wv = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
self.wo = nn.Linear(args.n_heads * args.head_dim, args.dim, bias=False)
MaybeLora = maybe_lora(args)
self.wq = MaybeLora(args.dim, args.n_heads * args.head_dim, bias=False)
self.wk = MaybeLora(args.dim, args.n_kv_heads * args.head_dim, bias=False)
self.wv = MaybeLora(args.dim, args.n_kv_heads * args.head_dim, bias=False)
self.wo = MaybeLora(args.n_heads * args.head_dim, args.dim, bias=False)

def forward(
self,
Expand Down Expand Up @@ -127,9 +139,10 @@ class FeedForward(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()

self.w1 = nn.Linear(args.dim, args.hidden_dim, bias=False)
self.w2 = nn.Linear(args.hidden_dim, args.dim, bias=False)
self.w3 = nn.Linear(args.dim, args.hidden_dim, bias=False)
MaybeLora = maybe_lora(args)
self.w1 = MaybeLora(args.dim, args.hidden_dim, bias=False)
self.w2 = MaybeLora(args.hidden_dim, args.dim, bias=False)
self.w3 = MaybeLora(args.dim, args.hidden_dim, bias=False)

def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.w2(nn.functional.silu(self.w1(x)) * self.w3(x)) # type: ignore
Expand Down Expand Up @@ -179,7 +192,7 @@ def forward(
return out


class Transformer(nn.Module):
class Transformer(nn.Module, LoRALoaderMixin):
def __init__(
self,
args: ModelArgs,
Expand Down

0 comments on commit cd06b0d

Please sign in to comment.