From c282eb78e735c176359a99e767d8e78f5fc8f6bd Mon Sep 17 00:00:00 2001 From: Muspi Merol Date: Sat, 9 Dec 2023 04:32:37 +0800 Subject: [PATCH] fix: inconsistency behavior parsing chat markup feat: add more unit tests for chat module --- python/promplate/prompt/chat.py | 2 +- python/tests/test_chat.py | 16 ++++++++++++++++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/python/promplate/prompt/chat.py b/python/promplate/prompt/chat.py index 28aa7ff..8950976 100644 --- a/python/promplate/prompt/chat.py +++ b/python/promplate/prompt/chat.py @@ -101,4 +101,4 @@ def parse_chat_markup(text: str) -> list[Message]: current_message["content"] = "\n".join(buffer) messages.append(current_message) - return messages or [{"role": "user", "content": text}] + return messages or [{"role": "user", "content": text.removesuffix("\n")}] diff --git a/python/tests/test_chat.py b/python/tests/test_chat.py index 0c6eb8a..f02007d 100644 --- a/python/tests/test_chat.py +++ b/python/tests/test_chat.py @@ -1,3 +1,5 @@ +from pytest import raises + from promplate.prompt.chat import A, S, U, parse_chat_markup @@ -16,3 +18,17 @@ def test_builder_names(): def test_chat_markup(): assert parse_chat_markup("hi") == [U > "hi"] assert parse_chat_markup("<| user name |>") == [U @ "name" > ""] + assert parse_chat_markup("123\n234\n\n345") == [U > "123\n234\n\n345"] + assert parse_chat_markup("<|user|>\n123\n<|assistant|>\n456") == [U > "123", A > "456"] + + +def test_auto_trim_trailing_blank_line(): + assert [U > "123"] == parse_chat_markup("123") + assert [U > "123"] == parse_chat_markup("123\n") + assert [U > "123"] == parse_chat_markup("<|user|>\n123") + assert [U > "123"] == parse_chat_markup("<|user|>\n123\n") + + +def test_immutable_constants(): + with raises(AssertionError): + A.content = "content"