Skip to content

Commit

Permalink
fix: strip empty messages and only show relevant messages on fim
Browse files Browse the repository at this point in the history
  • Loading branch information
yrobla committed Jan 3, 2025
1 parent 8b95d7f commit 016949c
Showing 1 changed file with 53 additions and 28 deletions.
81 changes: 53 additions & 28 deletions src/codegate/dashboard/post_processing.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import asyncio
import json
import re
from typing import List, Optional, Tuple, Union
from typing import Any, List, Optional, Tuple, Union

import structlog

Expand All @@ -24,6 +24,8 @@
"you don't need ending punctuation.",
]

# Length threshold, in characters, for answers in conversations of type 'fim':
# such a conversation is only shown when at least one of its answers is
# strictly longer than this value.
MEANINGFUL_MESSAGE_LENGTH = 50


async def _is_system_prompt(message: str) -> bool:
"""
Expand Down Expand Up @@ -90,7 +92,7 @@ async def parse_output(output_str: str) -> Tuple[Optional[str], Optional[str]]:
logger.warning(f"Error parsing output: {output_str}. {e}")
return None, None

def _parse_single_output(single_output: dict) -> str:
def _parse_single_output(single_output: dict) -> Tuple[Any, Any]:
single_chat_id = single_output.get("id")
single_output_message = ""
for choice in single_output.get("choices", []):
Expand Down Expand Up @@ -134,6 +136,10 @@ async def _get_question_answer(
The row contains the raw request and output strings from the pipeline.
"""
async with asyncio.TaskGroup() as tg:
if not row.request:
return None, None
if not row.output:
return None, None
request_task = tg.create_task(parse_request(row.request))
output_task = tg.create_task(parse_output(row.output))

Expand All @@ -150,11 +156,11 @@ async def _get_question_answer(
message_id=row.id,
)
if output_msg_str:
output_message = ChatMessage(
message=output_msg_str,
timestamp=row.output_timestamp,
message_id=row.output_id,
)
if row and row.output_timestamp and row.output_id:
output_message = ChatMessage(
message=output_msg_str,
timestamp=row.output_timestamp,
message_id=row.output_id,)
else:
output_message = None
chat_id = row.id
Expand All @@ -181,7 +187,7 @@ async def parse_get_prompt_with_output(
)


def parse_question_answer(input_text: str) -> str:
def _parse_question_answer_question(input_text: str) -> str:
# given a string, detect if we have a pattern of "Context: xxx \n\nQuery: xxx" and strip it
pattern = r'^Context:.*?\n\n\s*Query:\s*(.*)$'

Expand All @@ -190,11 +196,17 @@ def parse_question_answer(input_text: str) -> str:

# If a match is found, return the captured group after "Query:"
if match:
input_text = match.group(1)
return match.group(1)
else:
return input_text


def _parse_question_answer_answer(input_text: str) -> str:
# need to remove blank spaces and newlines
return input_text.strip()


async def match_conversations(
partial_conversations: List[Optional[PartialConversation]],
) -> List[Conversation]:
Expand Down Expand Up @@ -224,22 +236,34 @@ async def match_conversations(
for partial_conversation in sorted_convers:
# check if we have an answer, otherwise do not add it
if partial_conversation.question_answer.answer is not None:
first_partial_conversation = partial_conversation
partial_conversation.question_answer.question.message = parse_question_answer(
partial_conversation.question_answer.question.message)
questions_answers.append(partial_conversation.question_answer)
partial_conversation.question_answer.answer.message = _parse_question_answer_answer(partial_conversation.question_answer.answer.message)
if partial_conversation.question_answer.answer.message:
first_partial_conversation = partial_conversation
partial_conversation.question_answer.question.message = _parse_question_answer_question(
partial_conversation.question_answer.question.message)
questions_answers.append(partial_conversation.question_answer)

# only add conversation if we have some answers
if len(questions_answers) > 0 and first_partial_conversation is not None:
conversations.append(
Conversation(
question_answers=questions_answers,
provider=first_partial_conversation.provider,
type=first_partial_conversation.type,
chat_id=chat_id,
conversation_timestamp=sorted_convers[0].request_timestamp,
# if type is fim, we will only show the meaningful answers
can_show = True
if first_partial_conversation.type == 'fim':
# only show it if length of the answer is meaningful
can_show = False
for qa in questions_answers:
if len(qa.answer.message) > MEANINGFUL_MESSAGE_LENGTH:
can_show = True
break
if can_show:
conversations.append(
Conversation(
question_answers=questions_answers,
provider=first_partial_conversation.provider,
type=first_partial_conversation.type,
chat_id=chat_id,
conversation_timestamp=sorted_convers[0].request_timestamp,
)
)
)

return conversations

Expand Down Expand Up @@ -272,13 +296,14 @@ async def parse_row_alert_conversation(
if not question_answer or not chat_id:
return None

conversation = Conversation(
question_answers=[question_answer],
provider=row.provider,
type=row.type,
chat_id=chat_id or "chat-id-not-found",
conversation_timestamp=row.timestamp,
)
if row and row.type:
conversation = Conversation(
question_answers=[question_answer],
provider=row.provider,
type=row.type,
chat_id=chat_id or "chat-id-not-found",
conversation_timestamp=row.timestamp,
)
code_snippet = json.loads(row.code_snippet) if row.code_snippet else None
trigger_string = None
if row.trigger_string:
Expand All @@ -300,7 +325,7 @@ async def parse_row_alert_conversation(

async def parse_get_alert_conversation(
alerts_conversations: List[GetAlertsWithPromptAndOutputRow],
) -> List[AlertConversation]:
) -> List[AlertConversation | None]:
"""
Parse a list of rows from the get_alerts_with_prompt_and_output query and return a list of
AlertConversation
Expand Down

0 comments on commit 016949c

Please sign in to comment.