Commit

Remove extraneous assistant message headers (#711)
* Remove extraneous assistant message headers

* Update CHANGELOG.md

---------

Co-authored-by: jjallaire-aisi <joseph.allaire@dsit.gov.uk>
dragonstyle and jjallaire-aisi authored Oct 16, 2024
1 parent 03d288f commit 23cba31
Showing 2 changed files with 11 additions and 2 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -5,6 +5,7 @@
 - Cleanup Docker Containers immediately for samples with errors.
 - Anthropic: remove stock tool use chain of thought prompt (many Anthropic models now do this internally, in other cases its better for this to be explicit rather than implicit).
 - Google: compatibility with google-generativeai v0.8.3
+- Llama: remove extraneous <|start_header_id|>assistant<|end_header_id|> if it appears in an assistant message.
 - Requirements: require semver>=3.0.0
 - Open log files in binary mode when reading headers (fixes ijson deprecation warning).
 - Bugfix: strip protocol prefix when resolving eval event content
12 changes: 10 additions & 2 deletions src/inspect_ai/model/_providers/util/llama31.py
@@ -104,12 +104,16 @@ def parse_assistant_response(

             # return the message
             return ChatMessageAssistant(
-                content=content, tool_calls=tool_calls, source="generate"
+                content=filter_assistant_header(content),
+                tool_calls=tool_calls,
+                source="generate",
             )

         # otherwise this is just an ordinary assistant message
         else:
-            return ChatMessageAssistant(content=response, source="generate")
+            return ChatMessageAssistant(
+                content=filter_assistant_header(response), source="generate"
+            )

     @override
     def assistant_message(self, message: ChatMessageAssistant) -> ChatAPIMessage:
@@ -183,3 +187,7 @@ def parse_tool_call_content(content: str, tools: list[ToolInfo]) -> ToolCall:
         type="function",
         parse_error=parse_error,
     )
+
+
+def filter_assistant_header(message: str) -> str:
+    return re.sub(r"<\|start_header_id\|>assistant<\|end_header_id\|>", "", message)
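
To illustrate what the new helper does outside the provider, here is a minimal standalone sketch. It reuses the regular expression from the diff; the precompiled `ASSISTANT_HEADER` constant and the sample strings are assumptions made for illustration and are not part of the commit.

```python
import re

# Same pattern as in the commit; precompiling it is an illustrative choice,
# not something the commit itself does.
ASSISTANT_HEADER = re.compile(r"<\|start_header_id\|>assistant<\|end_header_id\|>")


def filter_assistant_header(message: str) -> str:
    """Remove every literal assistant header sequence from the message."""
    return ASSISTANT_HEADER.sub("", message)


# Hypothetical model output that echoes the header before its answer.
raw = "<|start_header_id|>assistant<|end_header_id|>\n\nThe answer is 4."
print(repr(filter_assistant_header(raw)))

# Messages without the header pass through unchanged.
print(repr(filter_assistant_header("The answer is 4.")))
```

Note that only the header tokens are removed: any whitespace that followed the header (the two newlines in the sample above) stays in the content, and headers for other roles such as `user` are left untouched.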
