Skip to content

Commit

Permalink
Refactor callbacks (#1583)
Browse files Browse the repository at this point in the history
* Unify Workflow and Verb callbacks interfaces

* Semver

* Fix storage class instantiation (#1582)

---------

Co-authored-by: Josh Bradley <[email protected]>
  • Loading branch information
natoverse and jgbradley1 authored Jan 6, 2025
1 parent cbb8f87 commit 7ec9ef0
Show file tree
Hide file tree
Showing 70 changed files with 193 additions and 367 deletions.
4 changes: 4 additions & 0 deletions .semversioner/next-release/patch-20250103231659816022.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{
"type": "patch",
"description": "Simplify callbacks model."
}
4 changes: 2 additions & 2 deletions docs/examples_notebooks/index_migration.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@
"outputs": [],
"source": [
"from graphrag.cache.factory import create_cache\n",
"from graphrag.callbacks.noop_verb_callbacks import NoopVerbCallbacks\n",
"from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks\n",
"from graphrag.index.flows.generate_text_embeddings import generate_text_embeddings\n",
"\n",
"# We only need to re-run the embeddings workflow, to ensure that embeddings for all required search fields are in place\n",
Expand All @@ -219,7 +219,7 @@
"config = workflow.config\n",
"text_embed = config.get(\"text_embed\", {})\n",
"embedded_fields = config.get(\"embedded_fields\", {})\n",
"callbacks = NoopVerbCallbacks()\n",
"callbacks = NoopWorkflowCallbacks()\n",
"cache = create_cache(pipeline_config.cache, PROJECT_DIRECTORY)\n",
"\n",
"await generate_text_embeddings(\n",
Expand Down
4 changes: 2 additions & 2 deletions graphrag/api/prompt_tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from pydantic import PositiveInt, validate_call

from graphrag.callbacks.noop_verb_callbacks import NoopVerbCallbacks
from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.index.llm.load_llm import load_llm
from graphrag.logger.print_progress import PrintProgressLogger
Expand Down Expand Up @@ -99,7 +99,7 @@ async def generate_indexing_prompts(
"prompt_tuning",
config.llm,
cache=None,
callbacks=NoopVerbCallbacks(),
callbacks=NoopWorkflowCallbacks(),
)

if not domain:
Expand Down
6 changes: 3 additions & 3 deletions graphrag/callbacks/blob_workflow_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def _write_log(self, log: dict[str, Any]):
# update the blob's block count
self._num_blocks += 1

def on_error(
def error(
self,
message: str,
cause: BaseException | None = None,
Expand All @@ -100,10 +100,10 @@ def on_error(
"details": details,
})

def on_warning(self, message: str, details: dict | None = None):
def warning(self, message: str, details: dict | None = None):
"""Report a warning."""
self._write_log({"type": "warning", "data": message, "details": details})

def on_log(self, message: str, details: dict | None = None):
def log(self, message: str, details: dict | None = None):
"""Report a generic log message."""
self._write_log({"type": "log", "data": message, "details": details})
6 changes: 3 additions & 3 deletions graphrag/callbacks/console_workflow_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
class ConsoleWorkflowCallbacks(NoopWorkflowCallbacks):
"""A logger that writes to a console."""

def on_error(
def error(
self,
message: str,
cause: BaseException | None = None,
Expand All @@ -19,11 +19,11 @@ def on_error(
"""Handle when an error occurs."""
print(message, str(cause), stack, details) # noqa T201

def on_warning(self, message: str, details: dict | None = None):
def warning(self, message: str, details: dict | None = None):
"""Handle when a warning occurs."""
_print_warning(message)

def on_log(self, message: str, details: dict | None = None):
def log(self, message: str, details: dict | None = None):
"""Handle when a log message is produced."""
print(message, details) # noqa T201

Expand Down
46 changes: 0 additions & 46 deletions graphrag/callbacks/delegating_verb_callbacks.py

This file was deleted.

6 changes: 3 additions & 3 deletions graphrag/callbacks/file_workflow_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def __init__(self, directory: str):
Path(directory) / "logs.json", "a", encoding="utf-8", errors="strict"
)

def on_error(
def error(
self,
message: str,
cause: BaseException | None = None,
Expand All @@ -50,7 +50,7 @@ def on_error(
message = f"{message} details={details}"
log.info(message)

def on_warning(self, message: str, details: dict | None = None):
def warning(self, message: str, details: dict | None = None):
"""Handle when a warning occurs."""
self._out_stream.write(
json.dumps(
Expand All @@ -61,7 +61,7 @@ def on_warning(self, message: str, details: dict | None = None):
)
_print_warning(message)

def on_log(self, message: str, details: dict | None = None):
def log(self, message: str, details: dict | None = None):
"""Handle when a log message is produced."""
self._out_stream.write(
json.dumps(
Expand Down
35 changes: 0 additions & 35 deletions graphrag/callbacks/noop_verb_callbacks.py

This file was deleted.

23 changes: 6 additions & 17 deletions graphrag/callbacks/noop_workflow_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,31 +3,23 @@

"""A no-op implementation of WorkflowCallbacks."""

from typing import Any

from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.logger.progress import Progress


class NoopWorkflowCallbacks(WorkflowCallbacks):
"""A no-op implementation of WorkflowCallbacks."""

def on_workflow_start(self, name: str, instance: object) -> None:
def workflow_start(self, name: str, instance: object) -> None:
"""Execute this callback when a workflow starts."""

def on_workflow_end(self, name: str, instance: object) -> None:
def workflow_end(self, name: str, instance: object) -> None:
"""Execute this callback when a workflow ends."""

def on_step_start(self, step_name: str) -> None:
"""Execute this callback every time a step starts."""

def on_step_end(self, step_name: str, result: Any) -> None:
"""Execute this callback every time a step ends."""

def on_step_progress(self, step_name: str, progress: Progress) -> None:
def progress(self, progress: Progress) -> None:
"""Handle when progress occurs."""

def on_error(
def error(
self,
message: str,
cause: BaseException | None = None,
Expand All @@ -36,11 +28,8 @@ def on_error(
) -> None:
"""Handle when an error occurs."""

def on_warning(self, message: str, details: dict | None = None) -> None:
def warning(self, message: str, details: dict | None = None) -> None:
"""Handle when a warning occurs."""

def on_log(self, message: str, details: dict | None = None) -> None:
def log(self, message: str, details: dict | None = None) -> None:
"""Handle when a log message occurs."""

def on_measure(self, name: str, value: float, details: dict | None = None) -> None:
"""Handle when a measurement occurs."""
17 changes: 3 additions & 14 deletions graphrag/callbacks/progress_workflow_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@

"""A workflow callback manager that emits updates."""

from typing import Any

from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
from graphrag.logger.base import ProgressLogger
from graphrag.logger.progress import Progress
Expand All @@ -31,23 +29,14 @@ def _push(self, name: str) -> None:
def _latest(self) -> ProgressLogger:
return self._progress_stack[-1]

def on_workflow_start(self, name: str, instance: object) -> None:
def workflow_start(self, name: str, instance: object) -> None:
"""Execute this callback when a workflow starts."""
self._push(name)

def on_workflow_end(self, name: str, instance: object) -> None:
def workflow_end(self, name: str, instance: object) -> None:
"""Execute this callback when a workflow ends."""
self._pop()

def on_step_start(self, step_name: str) -> None:
"""Execute this callback every time a step starts."""
self._push(f"Step {step_name}")
self._latest(Progress(percent=0))

def on_step_end(self, step_name: str, result: Any) -> None:
"""Execute this callback every time a step ends."""
self._pop()

def on_step_progress(self, step_name: str, progress: Progress) -> None:
def progress(self, progress: Progress) -> None:
"""Handle when progress occurs."""
self._latest(progress)
38 changes: 0 additions & 38 deletions graphrag/callbacks/verb_callbacks.py

This file was deleted.

26 changes: 7 additions & 19 deletions graphrag/callbacks/workflow_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

"""Collection of callbacks that can be used to monitor the workflow execution."""

from typing import Any, Protocol
from typing import Protocol

from graphrag.logger.progress import Progress

Expand All @@ -15,27 +15,19 @@ class WorkflowCallbacks(Protocol):
This base class is a "noop" implementation so that clients may implement just the callbacks they need.
"""

def on_workflow_start(self, name: str, instance: object) -> None:
def workflow_start(self, name: str, instance: object) -> None:
"""Execute this callback when a workflow starts."""
...

def on_workflow_end(self, name: str, instance: object) -> None:
def workflow_end(self, name: str, instance: object) -> None:
"""Execute this callback when a workflow ends."""
...

def on_step_start(self, step_name: str) -> None:
"""Execute this callback every time a step starts."""
...

def on_step_end(self, step_name: str, result: Any) -> None:
"""Execute this callback every time a step ends."""
...

def on_step_progress(self, step_name: str, progress: Progress) -> None:
def progress(self, progress: Progress) -> None:
"""Handle when progress occurs."""
...

def on_error(
def error(
self,
message: str,
cause: BaseException | None = None,
Expand All @@ -45,14 +37,10 @@ def on_error(
"""Handle when an error occurs."""
...

def on_warning(self, message: str, details: dict | None = None) -> None:
def warning(self, message: str, details: dict | None = None) -> None:
"""Handle when a warning occurs."""
...

def on_log(self, message: str, details: dict | None = None) -> None:
def log(self, message: str, details: dict | None = None) -> None:
"""Handle when a log message occurs."""
...

def on_measure(self, name: str, value: float, details: dict | None = None) -> None:
"""Handle when a measurement occurs."""
...
Loading

0 comments on commit 7ec9ef0

Please sign in to comment.