Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Test and unify text splitter functionality #1547

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .semversioner/next-release/patch-20241220191518597340.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{
"type": "patch",
"description": "unit tests for text_splitting"
}
45 changes: 5 additions & 40 deletions graphrag/index/operations/chunk_text/strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@

import graphrag.config.defaults as defs
from graphrag.index.operations.chunk_text.typing import TextChunk
from graphrag.index.text_splitting.text_splitting import Tokenizer
from graphrag.index.text_splitting.text_splitting import (
Tokenizer,
split_multiple_texts_on_tokens,
)


def run_tokens(
Expand All @@ -32,7 +35,7 @@ def encode(text: str) -> list[int]:
def decode(tokens: list[int]) -> str:
return enc.decode(tokens)

return _split_text_on_tokens(
return split_multiple_texts_on_tokens(
input,
Tokenizer(
chunk_overlap=chunk_overlap,
Expand All @@ -44,44 +47,6 @@ def decode(tokens: list[int]) -> str:
)


# Adapted from - https://github.com/langchain-ai/langchain/blob/77b359edf5df0d37ef0d539f678cf64f5557cb54/libs/langchain/langchain/text_splitter.py#L471
# So we could have better control over the chunking process
def _split_text_on_tokens(
texts: list[str], enc: Tokenizer, tick: ProgressTicker
) -> list[TextChunk]:
"""Split incoming text and return chunks."""
result = []
mapped_ids = []

for source_doc_idx, text in enumerate(texts):
encoded = enc.encode(text)
tick(1)
mapped_ids.append((source_doc_idx, encoded))

input_ids: list[tuple[int, int]] = [
(source_doc_idx, id) for source_doc_idx, ids in mapped_ids for id in ids
]

start_idx = 0
cur_idx = min(start_idx + enc.tokens_per_chunk, len(input_ids))
chunk_ids = input_ids[start_idx:cur_idx]
while start_idx < len(input_ids):
chunk_text = enc.decode([id for _, id in chunk_ids])
doc_indices = list({doc_idx for doc_idx, _ in chunk_ids})
result.append(
TextChunk(
text_chunk=chunk_text,
source_doc_indices=doc_indices,
n_tokens=len(chunk_ids),
)
)
start_idx += enc.tokens_per_chunk - enc.chunk_overlap
cur_idx = min(start_idx + enc.tokens_per_chunk, len(input_ids))
chunk_ids = input_ids[start_idx:cur_idx]

return result


def run_sentences(
input: list[str], _args: dict[str, Any], tick: ProgressTicker
) -> Iterable[TextChunk]:
Expand Down
138 changes: 48 additions & 90 deletions graphrag/index/text_splitting/text_splitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,17 @@

"""A module containing the 'Tokenizer', 'TextSplitter', 'NoopTextSplitter' and 'TokenTextSplitter' models."""

import json
import logging
from abc import ABC, abstractmethod
from collections.abc import Callable, Collection, Iterable
from dataclasses import dataclass
from enum import Enum
from typing import Any, Literal, cast

import pandas as pd
import tiktoken
from datashaper import ProgressTicker

from graphrag.index.utils.tokens import num_tokens_from_string
from graphrag.index.operations.chunk_text.typing import TextChunk

EncodedText = list[int]
DecodeFn = Callable[[EncodedText], str]
Expand Down Expand Up @@ -122,10 +121,10 @@ def num_tokens(self, text: str) -> int:

def split_text(self, text: str | list[str]) -> list[str]:
"""Split text method."""
if cast("bool", pd.isna(text)) or text == "":
return []
if isinstance(text, list):
text = " ".join(text)
elif cast("bool", pd.isna(text)) or text == "":
return []
if not isinstance(text, str):
msg = f"Attempting to split a non-string value, actual is {type(text)}"
raise TypeError(msg)
Expand All @@ -137,108 +136,67 @@ def split_text(self, text: str | list[str]) -> list[str]:
encode=lambda text: self.encode(text),
)

return split_text_on_tokens(text=text, tokenizer=tokenizer)


class TextListSplitterType(str, Enum):
"""Enum for the type of the TextListSplitter."""
return split_single_text_on_tokens(text=text, tokenizer=tokenizer)

DELIMITED_STRING = "delimited_string"
JSON = "json"


class TextListSplitter(TextSplitter):
"""Text list splitter class definition."""

def __init__(
self,
chunk_size: int,
splitter_type: TextListSplitterType = TextListSplitterType.JSON,
input_delimiter: str | None = None,
output_delimiter: str | None = None,
model_name: str | None = None,
encoding_name: str | None = None,
):
"""Initialize the TextListSplitter with a chunk size."""
# Set the chunk overlap to 0 as we use full strings
super().__init__(chunk_size, chunk_overlap=0)
self._type = splitter_type
self._input_delimiter = input_delimiter
self._output_delimiter = output_delimiter or "\n"
self._length_function = lambda x: num_tokens_from_string(
x, model=model_name, encoding_name=encoding_name
)

def split_text(self, text: str | list[str]) -> Iterable[str]:
"""Split a string list into a list of strings for a given chunk size."""
if not text:
return []

result: list[str] = []
current_chunk: list[str] = []
def split_text_on_tokens(
texts: str | list[str], tokenizer: Tokenizer, tick=None
) -> list[str] | list[TextChunk]:
"""Handle both single text and list of texts."""
if isinstance(texts, str):
return split_single_text_on_tokens(texts, tokenizer)

# Add the brackets
current_length: int = self._length_function("[]")
return split_multiple_texts_on_tokens(texts, tokenizer, tick)

# Input should be a string list joined by a delimiter
string_list = self._load_text_list(text)

if len(string_list) == 1:
return string_list

for item in string_list:
# Count the length of the item and add comma
item_length = self._length_function(f"{item},")
def split_single_text_on_tokens(text: str, tokenizer: Tokenizer) -> list[str]:
"""Split a single text and return chunks using the tokenizer."""
result = []
input_ids = tokenizer.encode(text)

if current_length + item_length > self._chunk_size:
if current_chunk and len(current_chunk) > 0:
# Add the current chunk to the result
self._append_to_result(result, current_chunk)
start_idx = 0
cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
chunk_ids = input_ids[start_idx:cur_idx]

# Start a new chunk
current_chunk = [item]
# Add 2 for the brackets
current_length = item_length
else:
# Add the item to the current chunk
current_chunk.append(item)
# Add 1 for the comma
current_length += item_length
while start_idx < len(input_ids):
chunk_text = tokenizer.decode(list(chunk_ids))
result.append(chunk_text) # Append chunked text as string
start_idx += tokenizer.tokens_per_chunk - tokenizer.chunk_overlap
cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
chunk_ids = input_ids[start_idx:cur_idx]

# Add the last chunk to the result
self._append_to_result(result, current_chunk)
return result

return result

def _load_text_list(self, text: str | list[str]):
"""Load the text list based on the type."""
if isinstance(text, list):
string_list = text
elif self._type == TextListSplitterType.JSON:
string_list = json.loads(text)
else:
string_list = text.split(self._input_delimiter)
return string_list
# Adapted from - https://github.com/langchain-ai/langchain/blob/77b359edf5df0d37ef0d539f678cf64f5557cb54/libs/langchain/langchain/text_splitter.py#L471
# So we could have better control over the chunking process
def split_multiple_texts_on_tokens(
texts: list[str], tokenizer: Tokenizer, tick: ProgressTicker | None = None
) -> list[TextChunk]:
"""Split multiple texts and return chunks with metadata using the tokenizer."""
result = []
mapped_ids = []

def _append_to_result(self, chunk_list: list[str], new_chunk: list[str]):
"""Append the current chunk to the result."""
if new_chunk and len(new_chunk) > 0:
if self._type == TextListSplitterType.JSON:
chunk_list.append(json.dumps(new_chunk, ensure_ascii=False))
else:
chunk_list.append(self._output_delimiter.join(new_chunk))
for source_doc_idx, text in enumerate(texts):
encoded = tokenizer.encode(text)
if tick:
tick(1) # Track progress if tick callback is provided
mapped_ids.append((source_doc_idx, encoded))

input_ids = [
(source_doc_idx, id) for source_doc_idx, ids in mapped_ids for id in ids
]

def split_text_on_tokens(*, text: str, tokenizer: Tokenizer) -> list[str]:
"""Split incoming text and return chunks using tokenizer."""
splits: list[str] = []
input_ids = tokenizer.encode(text)
start_idx = 0
cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
chunk_ids = input_ids[start_idx:cur_idx]

while start_idx < len(input_ids):
splits.append(tokenizer.decode(chunk_ids))
chunk_text = tokenizer.decode([id for _, id in chunk_ids])
doc_indices = list({doc_idx for doc_idx, _ in chunk_ids})
result.append(TextChunk(chunk_text, doc_indices, len(chunk_ids)))
start_idx += tokenizer.tokens_per_chunk - tokenizer.chunk_overlap
cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
chunk_ids = input_ids[start_idx:cur_idx]
return splits

return result
2 changes: 2 additions & 0 deletions tests/unit/indexing/text_splitting/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
Loading
Loading