Merge pull request stanfordnlp#773 from stanfordnlp/curieo-org-initial-groq-support

Curieo org initial groq support
arnavsinghvi11 authored Apr 5, 2024
2 parents a12b362 + 969220b commit 37ee5dc
Showing 8 changed files with 252 additions and 6 deletions.
51 changes: 51 additions & 0 deletions docs/api/language_model_clients/Groq.md
@@ -0,0 +1,51 @@
---
sidebar_position: 9
---

# dspy.GROQ

### Usage

```python
lm = dspy.GROQ(model='mixtral-8x7b-32768', api_key="gsk_***")
```
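
Once constructed, the client can be set as DSPy's default LM (a minimal sketch; the key and the question/answer signature are illustrative):

```python
import dspy

lm = dspy.GROQ(model='mixtral-8x7b-32768', api_key="gsk_***")  # hypothetical key
dspy.settings.configure(lm=lm)

# Any DSPy module now routes its calls through the Groq client.
qa = dspy.Predict("question -> answer")
print(qa(question="What is the capital of France?").answer)
```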

### Constructor

The constructor initializes the base class `LM` and verifies the provided arguments, such as the `api_key` for the Groq API client. The `kwargs` attribute is initialized with default values for the text generation parameters needed for communicating with the Groq API, such as `temperature`, `max_tokens`, `top_p`, `frequency_penalty`, `presence_penalty`, and `n`.

```python
class GroqLM(LM):
    def __init__(
        self,
        api_key: str,
        model: str = "mixtral-8x7b-32768",
        **kwargs,
    ):
```



**Parameters:**
- `api_key` (_str_): API provider authentication token. Required; a `ValueError` is raised if it is missing.
- `model` (_str_): Model name. Defaults to `"mixtral-8x7b-32768"`. Other options include `"llama2-70b-4096"` and `"gemma-7b-it"`.
- `**kwargs`: Additional language model arguments to pass to the API provider.
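
For example, the default generation parameters can be overridden at construction time (a sketch; the key and parameter values are illustrative):

```python
lm = dspy.GROQ(
    model='mixtral-8x7b-32768',
    api_key="gsk_***",   # hypothetical key for illustration
    temperature=0.7,     # overrides the 0.0 default
    max_tokens=512,      # overrides the 150 default
)
```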

### Methods

#### `def __call__(self, prompt: str, only_completed: bool = True, return_sorted: bool = False, **kwargs) -> list[dict[str, Any]]:`

Retrieves completions from Groq by calling `request`.

Internally, the method handles the specifics of preparing the request prompt and corresponding payload to obtain the response.

After generation, the content of each returned choice is accessed as `choice["message"]["content"]`.

**Parameters:**
- `prompt` (_str_): Prompt to send to Groq.
- `only_completed` (_bool_, _optional_): Flag to return only completed responses and ignore completion due to length. Defaults to True.
- `return_sorted` (_bool_, _optional_): Flag to sort the completion choices using the returned averaged log-probabilities. Defaults to False.
- `**kwargs`: Additional keyword arguments for completion request.

**Returns:**
- `List[Dict[str, Any]]`: List of completion choices.
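
A minimal direct call looks like this (sketch; the prompt is illustrative):

```python
completions = lm("Say hello in one short sentence.")
print(completions[0])
```
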
2 changes: 2 additions & 0 deletions dsp/modules/__init__.py
@@ -8,10 +8,12 @@
 from .databricks import *
 from .google import *
 from .gpt3 import *
+from .groq_client import *
 from .hf import HFModel
 from .hf_client import Anyscale, HFClientTGI, Together
 from .mistral import *
 from .ollama import *
 from .pyserini import *
 from .sbert import *
 from .sentence_vectorizer import *

169 changes: 169 additions & 0 deletions dsp/modules/groq_client.py
@@ -0,0 +1,169 @@
import logging
from typing import Any

import backoff

try:
    import groq
    from groq import Groq

    groq_api_error = (groq.APIError, groq.RateLimitError)
except ImportError:
    # Fall back to a generic exception tuple when the groq package is not installed.
    groq_api_error = (Exception,)


import dsp
from dsp.modules.lm import LM

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format="%(message)s",
    handlers=[logging.FileHandler("groq_usage.log")],
)



def backoff_hdlr(details):
    """Handler from https://pypi.org/project/backoff/"""
    print(
        "Backing off {wait:0.1f} seconds after {tries} tries "
        "calling function {target} with kwargs "
        "{kwargs}".format(**details),
    )


class GroqLM(LM):
    """Wrapper around Groq's API.

    Args:
        model (str, optional): Groq-supported LLM model to use. Defaults to "mixtral-8x7b-32768".
        api_key (str): API provider authentication token. Required.
        **kwargs: Additional arguments to pass to the API provider.
    """

    def __init__(
        self,
        api_key: str,
        model: str = "mixtral-8x7b-32768",
        **kwargs,
    ):
        super().__init__(model)
        self.provider = "groq"
        if api_key:
            self.api_key = api_key
            self.client = Groq(api_key=api_key)
        else:
            raise ValueError("api_key is required for groq")


        # Default text generation parameters; user-supplied kwargs override these.
        self.kwargs = {
            "temperature": 0.0,
            "max_tokens": 150,
            "top_p": 1,
            "frequency_penalty": 0,
            "presence_penalty": 0,
            "n": 1,
            **kwargs,
        }
        # Record the requested model; validation against the live model list is best-effort.
        self.kwargs["model"] = model
        models = self.client.models.list().data
        if models is not None and model not in [m.id for m in models]:
            logging.warning("Model %s was not found in Groq's model list.", model)
        self.history: list[dict[str, Any]] = []


    def log_usage(self, response):
        """Log the total tokens from the Groq API response."""
        # The response is a Groq response object, not a dict, so use attribute access.
        usage_data = getattr(response, "usage", None)
        if usage_data:
            total_tokens = usage_data.total_tokens
            logging.info(f"{total_tokens}")

    def basic_request(self, prompt: str, **kwargs):
        raw_kwargs = kwargs

        kwargs = {**self.kwargs, **kwargs}

        kwargs["messages"] = [{"role": "user", "content": prompt}]
        response = self.chat_request(**kwargs)

        history = {
            "prompt": prompt,
            "response": response.choices[0].message.content,
            "kwargs": kwargs,
            "raw_kwargs": raw_kwargs,
        }

        self.history.append(history)

        return response

    @backoff.on_exception(
        backoff.expo,
        groq_api_error,
        max_time=1000,
        on_backoff=backoff_hdlr,
    )
    def request(self, prompt: str, **kwargs):
        """Handles retrieval of model completions while handling rate limiting and caching."""
        if "model_type" in kwargs:
            del kwargs["model_type"]

        return self.basic_request(prompt, **kwargs)

    def _get_choice_text(self, choice) -> str:
        return choice.message.content

    def chat_request(self, **kwargs):
        """Sends a chat completion request to the Groq API and returns the raw response."""
        response = self.client.chat.completions.create(**kwargs)
        return response

    def __call__(
        self,
        prompt: str,
        only_completed: bool = True,
        return_sorted: bool = False,
        **kwargs,
    ) -> list[dict[str, Any]]:
        """Retrieves completions from the model.

        Args:
            prompt (str): prompt to send to the model
            only_completed (bool, optional): return only completed responses, ignoring completions truncated due to length. Defaults to True.
            return_sorted (bool, optional): sort the completion choices using the returned probabilities. Defaults to False.

        Returns:
            list[dict[str, Any]]: list of completion choices
        """

        assert only_completed, "for now"
        assert return_sorted is False, "for now"
        response = self.request(prompt, **kwargs)

        if dsp.settings.log_openai_usage:
            self.log_usage(response)

        choices = response.choices

        completions = [self._get_choice_text(c) for c in choices]
        # Unreachable while the assertion above forces return_sorted to False;
        # retained for future support of sorted returns.
        if return_sorted and kwargs.get("n", 1) > 1:
            scored_completions = []

            for c in choices:
                tokens, logprobs = (
                    c["logprobs"]["tokens"],
                    c["logprobs"]["token_logprobs"],
                )

                if "<|endoftext|>" in tokens:
                    index = tokens.index("<|endoftext|>") + 1
                    tokens, logprobs = tokens[:index], logprobs[:index]

                avglog = sum(logprobs) / len(logprobs)
                scored_completions.append((avglog, self._get_choice_text(c)))

            scored_completions = sorted(scored_completions, reverse=True)
            completions = [c for _, c in scored_completions]

        return completions
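
For reference, a rough sketch of exercising this client directly, assuming a valid key is available (the `GROQ_API_KEY` environment variable is an illustrative choice):

```python
import os

from dsp.modules.groq_client import GroqLM

# Hypothetical: read the key from the environment for illustration.
lm = GroqLM(api_key=os.environ["GROQ_API_KEY"])
print(lm("Briefly explain what a mixture-of-experts model is.")[0])
```
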
2 changes: 1 addition & 1 deletion dsp/modules/hf_client.py
@@ -435,4 +435,4 @@ def _generate(self, prompt, **kwargs):

 @CacheMemory.cache
 def send_hfsglang_request_v00(arg, **kwargs):
-    return requests.post(arg, **kwargs)
\ No newline at end of file
+    return requests.post(arg, **kwargs)
8 changes: 5 additions & 3 deletions dsp/modules/lm.py
@@ -77,9 +77,11 @@ def inspect_history(self, n: int = 1, skip: int = 0):
if provider == "cohere":
text = choices
elif provider == "openai" or provider == "ollama":
text = " " + self._get_choice_text(choices[0]).strip()
elif provider == "clarifai":
text = choices
text = ' ' + self._get_choice_text(choices[0]).strip()
elif provider == "clarifai" or provider == "claude" :
text=choices
elif provider == "groq":
text = ' ' + choices
elif provider == "google":
text = choices[0].parts[0].text
elif provider == "mistral":
1 change: 1 addition & 0 deletions dspy/__init__.py
@@ -21,6 +21,7 @@
 Pyserini = dsp.PyseriniRetriever
 Clarifai = dsp.ClarifaiLLM
 Google = dsp.Google
+GROQ = dsp.GroqLM

 HFClientTGI = dsp.HFClientTGI
 HFClientVLLM = HFClientVLLM
23 changes: 21 additions & 2 deletions poetry.lock

Some generated files are not rendered by default.

2 changes: 2 additions & 0 deletions pyproject.toml
@@ -57,6 +57,7 @@ docs = [
"autodoc_pydantic",
"sphinx-reredirects>=0.1.2",
"sphinx-automodapi==0.16.0",

]
dev = ["pytest>=6.2.5"]

@@ -108,6 +109,7 @@ sphinx_rtd_theme = { version = "*", optional = true }
 autodoc_pydantic = { version = "*", optional = true }
 sphinx-reredirects = { version = "^0.1.2", optional = true }
 sphinx-automodapi = { version = "0.16.0", optional = true }
+groq = {version = "^0.4.2", optional = true }
 rich = "^13.7.1"
 psycopg2 = {version = "^2.9.9", optional = true}
 pgvector = {version = "^0.2.5", optional = true}
