forked from stanfordnlp/dspy
Commit
Merge pull request stanfordnlp#773 from stanfordnlp/curieo-org-initial-groq-support

Curieo org initial groq support
Showing 8 changed files with 252 additions and 6 deletions.
@@ -0,0 +1,51 @@
---
sidebar_position: 9
---

# dspy.GROQ

### Usage

```python
lm = dspy.GROQ(model='mixtral-8x7b-32768', api_key="gsk_***")
```
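To route all DSPy modules through this client, it can be registered as the default LM via `dspy.settings.configure` (a minimal sketch; the exact configuration call may differ across DSPy versions):

```python
import dspy

lm = dspy.GROQ(model='mixtral-8x7b-32768', api_key="gsk_***")
dspy.settings.configure(lm=lm)  # subsequent DSPy modules route their calls through Groq
```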

### Constructor

The constructor initializes the base class `LM` and verifies the provided arguments, such as the `api_key` needed to create the Groq API client. The `kwargs` attribute is initialized with default values for the text generation parameters sent to the Groq API, such as `temperature`, `max_tokens`, `top_p`, `frequency_penalty`, `presence_penalty`, and `n`; these defaults can be overridden via `**kwargs`, as shown in the example after the parameter list below.

```python
class GroqLM(LM):
    def __init__(
        self,
        api_key: str,
        model: str = "mixtral-8x7b-32768",
        **kwargs,
    ):
```

**Parameters:**
- `api_key` (_str_): API provider authentication token. Required; the constructor raises a `ValueError` if it is missing.
- `model` (_str_): Model name. Defaults to `"mixtral-8x7b-32768"`. Other options include `"llama2-70b-4096"` and `"gemma-7b-it"`.
- `**kwargs`: Additional language model arguments to pass to the API provider.
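For example, any of the default generation parameters can be overridden at construction time (the values shown here are illustrative only):

```python
# Override a few of the default generation parameters; everything passed via
# **kwargs is merged into the payload sent with each Groq request.
lm = dspy.GROQ(
    model='llama2-70b-4096',
    api_key="gsk_***",
    temperature=0.3,
    max_tokens=500,
)
```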

### Methods

#### `def __call__(self, prompt: str, only_completed: bool = True, return_sorted: bool = False, **kwargs) -> list[dict[str, Any]]:`

Retrieves completions from Groq by calling `request`.

Internally, the method prepares the request prompt and the corresponding payload, sends the request, and parses the response.

After generation, the text of each completion is found at `choice["message"]["content"]` in the response.

**Parameters:**
- `prompt` (_str_): Prompt to send to Groq.
- `only_completed` (_bool_, _optional_): Flag to return only completed responses and ignore completions truncated due to length. Defaults to True.
- `return_sorted` (_bool_, _optional_): Flag to sort the completion choices using the returned averaged log-probabilities. Defaults to False.
- `**kwargs`: Additional keyword arguments for the completion request.

**Returns:**
- `List[Dict[str, Any]]`: List of completion choices.
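A minimal sketch of calling the client directly (the prompt and parameters are illustrative; per the implementation below, each element of the returned list is the text of one completion):

```python
lm = dspy.GROQ(model='mixtral-8x7b-32768', api_key="gsk_***")

# Extra keyword arguments are forwarded to the Groq chat completions request.
completions = lm("Summarize what DSPy is in one sentence.", temperature=0.7, max_tokens=100)
print(completions[0])
```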
@@ -0,0 +1,169 @@
import logging
from typing import Any

import backoff

try:
    import groq
    from groq import Groq

    # Exceptions that should trigger a retry with exponential backoff.
    groq_api_error = (groq.APIError, groq.RateLimitError)
except ImportError:
    # groq is not installed; fall back to retrying on any exception.
    groq_api_error = (Exception,)


import dsp
from dsp.modules.lm import LM

# Configure logging: token usage is appended to groq_usage.log.
logging.basicConfig(
    level=logging.INFO,
    format="%(message)s",
    handlers=[logging.FileHandler("groq_usage.log")],
)

def backoff_hdlr(details):
    """Handler from https://pypi.org/project/backoff/"""
    print(
        "Backing off {wait:0.1f} seconds after {tries} tries "
        "calling function {target} with kwargs "
        "{kwargs}".format(**details),
    )

class GroqLM(LM):
    """Wrapper around Groq's API.

    Args:
        model (str, optional): Groq-supported LLM model to use. Defaults to "mixtral-8x7b-32768".
        api_key (Optional[str], optional): API provider authentication token. Defaults to None.
        **kwargs: Additional arguments to pass to the API provider.
    """

    def __init__(
        self,
        api_key: str,
        model: str = "mixtral-8x7b-32768",
        **kwargs,
    ):
        super().__init__(model)
        self.provider = "groq"
        if api_key:
            self.api_key = api_key
            self.client = Groq(api_key=api_key)
        else:
            raise ValueError("api_key is required for groq")

        # Default generation parameters; callers can override any of them via **kwargs.
        self.kwargs = {
            "temperature": 0.0,
            "max_tokens": 150,
            "top_p": 1,
            "frequency_penalty": 0,
            "presence_penalty": 0,
            "n": 1,
            **kwargs,
        }
        # Only record the model if it is listed as available for this API key.
        models = self.client.models.list().data
        if models is not None:
            if model in [m.id for m in models]:
                self.kwargs["model"] = model
        self.history: list[dict[str, Any]] = []

    def log_usage(self, response):
        """Log the total tokens from the Groq API response."""
        # The chat completions response is an object, not a dict, so use attribute access.
        usage_data = getattr(response, "usage", None)
        if usage_data:
            total_tokens = usage_data.total_tokens
            logging.info(f"{total_tokens}")

    def basic_request(self, prompt: str, **kwargs):
        raw_kwargs = kwargs

        # Merge the default generation parameters with per-call overrides.
        kwargs = {**self.kwargs, **kwargs}

        kwargs["messages"] = [{"role": "user", "content": prompt}]
        response = self.chat_request(**kwargs)

        history = {
            "prompt": prompt,
            "response": response.choices[0].message.content,
            "kwargs": kwargs,
            "raw_kwargs": raw_kwargs,
        }

        self.history.append(history)

        return response

    @backoff.on_exception(
        backoff.expo,
        groq_api_error,
        max_time=1000,
        on_backoff=backoff_hdlr,
    )
    def request(self, prompt: str, **kwargs):
        """Handles retrieval of model completions whilst handling rate limiting and caching."""
        if "model_type" in kwargs:
            del kwargs["model_type"]

        return self.basic_request(prompt, **kwargs)

    def _get_choice_text(self, choice) -> str:
        return choice.message.content

    def chat_request(self, **kwargs):
        """Calls the Groq chat completions endpoint with the given keyword arguments."""
        response = self.client.chat.completions.create(**kwargs)
        return response

    def __call__(
        self,
        prompt: str,
        only_completed: bool = True,
        return_sorted: bool = False,
        **kwargs,
    ) -> list[dict[str, Any]]:
        """Retrieves completions from model.

        Args:
            prompt (str): prompt to send to model
            only_completed (bool, optional): return only completed responses and ignore completions truncated due to length. Defaults to True.
            return_sorted (bool, optional): sort the completion choices using the returned probabilities. Defaults to False.

        Returns:
            list[dict[str, Any]]: list of completion choices
        """

        assert only_completed, "for now"
        assert return_sorted is False, "for now"
        response = self.request(prompt, **kwargs)

        # Log token usage if enabled in the shared dsp settings.
        if dsp.settings.log_openai_usage:
            self.log_usage(response)

        choices = response.choices

        completions = [self._get_choice_text(c) for c in choices]
        if return_sorted and kwargs.get("n", 1) > 1:
            scored_completions = []

            for c in choices:
                tokens, logprobs = (
                    c["logprobs"]["tokens"],
                    c["logprobs"]["token_logprobs"],
                )

                # Truncate at the end-of-text token if present.
                if "<|endoftext|>" in tokens:
                    index = tokens.index("<|endoftext|>") + 1
                    tokens, logprobs = tokens[:index], logprobs[:index]

                # Score each completion by its average token log-probability.
                avglog = sum(logprobs) / len(logprobs)
                scored_completions.append((avglog, self._get_choice_text(c)))

            scored_completions = sorted(scored_completions, reverse=True)
            completions = [c for _, c in scored_completions]

        return completions