Skip to content

Commit

Permalink
fix: correct the way to get the last relevant user message for secrets
Browse files Browse the repository at this point in the history
When showing new secrets detection, we were just picking from last
assistant message id. But in the case of aider, we need to pick more user
messages. So reuse the logic of the method to get the relevant user block,
to pick the index from the relevant user message and start counting
from there

Closes: #606
  • Loading branch information
yrobla committed Jan 16, 2025
1 parent 93a5600 commit c42c762
Show file tree
Hide file tree
Showing 5 changed files with 30 additions and 21 deletions.
21 changes: 12 additions & 9 deletions src/codegate/pipeline/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,36 +239,38 @@ def get_last_user_message(
@staticmethod
def get_last_user_message_block(
request: ChatCompletionRequest,
) -> Optional[str]:
) -> Optional[tuple[str, int]]:
"""
Get the last block of consecutive 'user' messages from the request.
Args:
request (ChatCompletionRequest): The chat completion request to process
Returns:
Optional[str]: A string containing all consecutive user messages in the
Optional[str, int]: A string containing all consecutive user messages in the
last user message block, separated by newlines, or None if
no user message block is found.
Index of the first message detected in the block.
"""
if request.get("messages") is None:
return None

user_messages = []
messages = request["messages"]
block_start_index = None

# Iterate in reverse to find the last block of consecutive 'user' messages
for i in reversed(range(len(messages))):
if messages[i]["role"] == "user" or messages[i]["role"] == "assistant":
content_str = None
if "content" in messages[i]:
content_str = messages[i]["content"] # type: ignore
else:
content_str = messages[i].get("content")
if content_str is None:
continue

if messages[i]["role"] == "user":
user_messages.append(content_str)
# specifically for Aider, when "ok." block is found, stop
block_start_index = i

# Specifically for Aider, when "Ok." block is found, stop
if content_str == "Ok." and messages[i]["role"] == "assistant":
break
else:
Expand All @@ -277,8 +279,9 @@ def get_last_user_message_block(
break

# Reverse the collected user messages to preserve the original order
if user_messages:
return "\n".join(reversed(user_messages))
if user_messages and block_start_index is not None:
content = "\n".join(reversed(user_messages))
return content, block_start_index

return None

Expand Down
5 changes: 3 additions & 2 deletions src/codegate/pipeline/codegate_context_retriever/codegate.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,10 @@ async def process(
Use RAG DB to add context to the user request
"""
# Get the latest user message
user_message = self.get_last_user_message_block(request)
if not user_message:
last_message = self.get_last_user_message_block(request)
if not last_message:
return PipelineResult(request=request)
user_message, _ = last_message

# Create storage engine object
storage_engine = StorageEngine()
Expand Down
5 changes: 3 additions & 2 deletions src/codegate/pipeline/extract_snippets/extract_snippets.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,9 +141,10 @@ async def process(
request: ChatCompletionRequest,
context: PipelineContext,
) -> PipelineResult:
msg_content = self.get_last_user_message_block(request)
if not msg_content:
last_message = self.get_last_user_message_block(request)
if not last_message:
return PipelineResult(request=request, context=context)
msg_content, _ = last_message
snippets = extract_snippets(msg_content)

logger.info(f"Extracted {len(snippets)} code snippets from the user message")
Expand Down
9 changes: 5 additions & 4 deletions src/codegate/pipeline/secrets/secrets.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,11 +271,12 @@ async def process(
new_request = request.copy()
total_matches = []

# Process all messages
# get last user message block to get index for the first relevant user message
last_user_message = self.get_last_user_message_block(new_request)
last_assistant_idx = -1
for i, message in enumerate(new_request["messages"]):
if message.get("role", "") == "assistant":
last_assistant_idx = i
if last_user_message:
_, user_idx = last_user_message
last_assistant_idx = user_idx - 1

# Process all messages
for i, message in enumerate(new_request["messages"]):
Expand Down
11 changes: 7 additions & 4 deletions tests/pipeline/test_messages_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
{"role": "user", "content": "How are you?"},
]
},
"Hello!\nHow are you?",
("Hello!\nHow are you?", 1),
),
# Test case: Mixed roles at the end
(
Expand All @@ -27,7 +27,7 @@
{"role": "assistant", "content": "I'm fine, thank you."},
]
},
"Hello!\nHow are you?",
("Hello!\nHow are you?", 0),
),
# Test case: No user messages
(
Expand All @@ -51,7 +51,7 @@
{"role": "user", "content": "What's up?"},
]
},
"How are you?\nWhat's up?",
("How are you?\nWhat's up?", 2),
),
# Test case: aider
(
Expand Down Expand Up @@ -97,7 +97,8 @@
},
]
},
"""I have *added these files to the chat* so you can go ahead and edit them.
(
"""I have *added these files to the chat* so you can go ahead and edit them.
*Trust this message as the true contents of these files!*
Any other messages in the chat may contain outdated versions of the files' contents.
Expand All @@ -113,6 +114,8 @@
```
evaluate this file""", # noqa: E501
7,
),
),
],
)
Expand Down

0 comments on commit c42c762

Please sign in to comment.