Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: correct the way to get the last relevant user message for secrets #608

Merged
merged 3 commits into from
Jan 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion signatures.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
---
- Amazon:
- Access Key: (?:A3T[A-Z0-9]|AKIA|AGPA|AIDA|AROA|AIPA|ANPA|ANVA|ASIA|ABIA|ACCA)[A-Z0-9]{16}
- Secret Access Key: (?<![A-Za-z0-9/+])[A-Za-z0-9+=][A-Za-z0-9/+=]{38}[A-Za-z0-9+=](?![A-Za-z0-9/+=])
# - Cognito User Pool ID: (?i)us-[a-z]{2,}-[a-z]{4,}-\d{1,}
- RDS Password: (?i)(rds\-master\-password|db\-password)
- SNS Confirmation URL: (?i)https:\/\/sns\.[a-z0-9-]+\.amazonaws\.com\/?Action=ConfirmSubscription&Token=[a-zA-Z0-9-=_]+
Expand Down
21 changes: 12 additions & 9 deletions src/codegate/pipeline/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,36 +239,38 @@ def get_last_user_message(
@staticmethod
def get_last_user_message_block(
request: ChatCompletionRequest,
) -> Optional[str]:
) -> Optional[tuple[str, int]]:
"""
Get the last block of consecutive 'user' messages from the request.

Args:
request (ChatCompletionRequest): The chat completion request to process

Returns:
Optional[str]: A string containing all consecutive user messages in the
Optional[tuple[str, int]]: A string containing all consecutive user messages in the
last user message block, separated by newlines, or None if
no user message block is found.
Index of the first message detected in the block.
"""
if request.get("messages") is None:
return None

user_messages = []
messages = request["messages"]
block_start_index = None

# Iterate in reverse to find the last block of consecutive 'user' messages
for i in reversed(range(len(messages))):
if messages[i]["role"] == "user" or messages[i]["role"] == "assistant":
content_str = None
if "content" in messages[i]:
content_str = messages[i]["content"] # type: ignore
else:
content_str = messages[i].get("content")
if content_str is None:
continue

if messages[i]["role"] == "user":
user_messages.append(content_str)
# specifically for Aider, when "ok." block is found, stop
block_start_index = i

# Specifically for Aider, when "Ok." block is found, stop
if content_str == "Ok." and messages[i]["role"] == "assistant":
break
else:
Expand All @@ -277,8 +279,9 @@ def get_last_user_message_block(
break

# Reverse the collected user messages to preserve the original order
if user_messages:
return "\n".join(reversed(user_messages))
if user_messages and block_start_index is not None:
content = "\n".join(reversed(user_messages))
return content, block_start_index

return None

Expand Down
5 changes: 3 additions & 2 deletions src/codegate/pipeline/codegate_context_retriever/codegate.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,10 @@ async def process(
Use RAG DB to add context to the user request
"""
# Get the latest user message
user_message = self.get_last_user_message_block(request)
if not user_message:
last_message = self.get_last_user_message_block(request)
if not last_message:
return PipelineResult(request=request)
user_message, _ = last_message

# Create storage engine object
storage_engine = StorageEngine()
Expand Down
5 changes: 3 additions & 2 deletions src/codegate/pipeline/extract_snippets/extract_snippets.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,9 +141,10 @@ async def process(
request: ChatCompletionRequest,
context: PipelineContext,
) -> PipelineResult:
msg_content = self.get_last_user_message_block(request)
if not msg_content:
last_message = self.get_last_user_message_block(request)
if not last_message:
return PipelineResult(request=request, context=context)
msg_content, _ = last_message
snippets = extract_snippets(msg_content)

logger.info(f"Extracted {len(snippets)} code snippets from the user message")
Expand Down
17 changes: 10 additions & 7 deletions src/codegate/pipeline/secrets/secrets.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,11 +271,12 @@ async def process(
new_request = request.copy()
total_matches = []

# Process all messages
# get last user message block to get index for the first relevant user message
last_user_message = self.get_last_user_message_block(new_request)
last_assistant_idx = -1
for i, message in enumerate(new_request["messages"]):
if message.get("role", "") == "assistant":
last_assistant_idx = i
if last_user_message:
_, user_idx = last_user_message
last_assistant_idx = user_idx - 1

# Process all messages
for i, message in enumerate(new_request["messages"]):
Expand Down Expand Up @@ -312,8 +313,8 @@ class SecretUnredactionStep(OutputPipelineStep):
"""Pipeline step that unredacts protected content in the stream"""

def __init__(self):
self.redacted_pattern = re.compile(r"REDACTED<\$([^>]+)>")
self.marker_start = "REDACTED<$"
self.redacted_pattern = re.compile(r"REDACTED<(\$?[^>]+)>")
self.marker_start = "REDACTED<"
self.marker_end = ">"

@property
Expand Down Expand Up @@ -365,6 +366,8 @@ async def process_chunk(
if match:
# Found a complete marker, process it
encrypted_value = match.group(1)
if encrypted_value.startswith('$'):
encrypted_value = encrypted_value[1:]
original_value = input_context.sensitive.manager.get_original_value(
encrypted_value,
input_context.sensitive.session_id,
Expand Down Expand Up @@ -399,7 +402,7 @@ async def process_chunk(
return []

if self._is_partial_marker_prefix(buffered_content):
context.prefix_buffer += buffered_content
context.prefix_buffer = buffered_content
return []

# No markers or partial markers, let pipeline handle the chunk normally
Expand Down
11 changes: 7 additions & 4 deletions tests/pipeline/test_messages_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
{"role": "user", "content": "How are you?"},
]
},
"Hello!\nHow are you?",
("Hello!\nHow are you?", 1),
),
# Test case: Mixed roles at the end
(
Expand All @@ -27,7 +27,7 @@
{"role": "assistant", "content": "I'm fine, thank you."},
]
},
"Hello!\nHow are you?",
("Hello!\nHow are you?", 0),
),
# Test case: No user messages
(
Expand All @@ -51,7 +51,7 @@
{"role": "user", "content": "What's up?"},
]
},
"How are you?\nWhat's up?",
("How are you?\nWhat's up?", 2),
),
# Test case: aider
(
Expand Down Expand Up @@ -97,7 +97,8 @@
},
]
},
"""I have *added these files to the chat* so you can go ahead and edit them.
(
"""I have *added these files to the chat* so you can go ahead and edit them.

*Trust this message as the true contents of these files!*
Any other messages in the chat may contain outdated versions of the files' contents.
Expand All @@ -113,6 +114,8 @@
```

evaluate this file""", # noqa: E501
7,
),
),
],
)
Expand Down
Loading