From 95a9d97c5101579fc41b0954e1beab2663027a9a Mon Sep 17 00:00:00 2001 From: "Brandon T. Willard" Date: Thu, 19 Oct 2023 20:18:39 -0500 Subject: [PATCH] Add another missing walk_fsm condition The full-match option handling was not correct for scanned/walked strings with valid transitions but not ending in a final state. --- outlines/text/fsm.py | 9 ++++++++- tests/text/test_fsm.py | 26 +++++++++++++++++++------- 2 files changed, 27 insertions(+), 8 deletions(-) diff --git a/outlines/text/fsm.py b/outlines/text/fsm.py index 27b5af88e..f78f33e49 100644 --- a/outlines/text/fsm.py +++ b/outlines/text/fsm.py @@ -268,6 +268,9 @@ def _walk_fsm( accepted_states.append(_nonoptional(state)) + if full_match and last_final_idx - 1 != i: + return numba.typed.List.empty_list(numba.int64) + return accepted_states @@ -305,6 +308,9 @@ def walk_fsm( accepted_states.append(state) + if full_match and last_final_idx - 1 != i: + return [] + return accepted_states @@ -376,7 +382,7 @@ def process_token_string( res = set() vocab_string_len = len(token) - for end_idx, state_seq in find_partial_matches(fsm_info, token): + for end_idx, state_seq in find_partial_matches(fsm_info, token, full_match=False): if end_idx is not None and end_idx < vocab_string_len - 1: continue @@ -603,6 +609,7 @@ def state_scan_tokens( fsm_finals, token, start_state, + False, ) if state_seq is not None and len(state_seq) < len(token): diff --git a/tests/text/test_fsm.py b/tests/text/test_fsm.py index b19562c49..ce4a3647b 100644 --- a/tests/text/test_fsm.py +++ b/tests/text/test_fsm.py @@ -61,6 +61,18 @@ def test_walk_fsm(function): res = tuple(function(regex_fsm, "0", 1, full_match=True)) assert res == tuple() + regex_pattern = interegular.parse_pattern("0|[1-9][2-9]+") + regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm().reduce()) + + res = tuple(function(regex_fsm, "1", regex_fsm.initial, full_match=True)) + assert res == tuple() + + res = tuple(function(regex_fsm, "1", regex_fsm.initial, full_match=False)) + assert res == (2,) + + res = tuple(function(regex_fsm, "12", regex_fsm.initial, full_match=True)) + assert res == (2, 3) + pattern = interegular.parse_pattern(r"(?:[^\W\d]\w*|[\t \x0c]+)") fsm, _ = make_deterministic_fsm(pattern.to_fsm().reduce()) @@ -90,19 +102,19 @@ def to_python(res): res = to_python(find_partial_matches(def_fsm, "def")) assert res == {(2, (0, 1, 2, 3))} - res = to_python(find_partial_matches(def_fsm, "de")) + res = to_python(find_partial_matches(def_fsm, "de", full_match=False)) assert res == {(1, (0, 1, 2))} - res = to_python(find_partial_matches(def_fsm, "d")) + res = to_python(find_partial_matches(def_fsm, "d", full_match=False)) assert res == {(0, (0, 1))} res = to_python(find_partial_matches(def_fsm, "")) assert res == set() res = to_python(find_partial_matches(def_fsm, "df")) assert res == set() - res = to_python(find_partial_matches(def_fsm, "ef")) + res = to_python(find_partial_matches(def_fsm, "ef", full_match=False)) assert res == {(1, (1, 2, 3))} - res = to_python(find_partial_matches(def_fsm, "e")) + res = to_python(find_partial_matches(def_fsm, "e", full_match=False)) assert res == {(0, (1, 2))} - res = to_python(find_partial_matches(def_fsm, "f")) + res = to_python(find_partial_matches(def_fsm, "f", full_match=False)) assert res == {(0, (2, 3))} res = to_python(find_partial_matches(def_fsm, "ef foo", full_match=False)) assert res == {(1, (1, 2, 3))} @@ -112,7 +124,7 @@ def to_python(res): assert res == {(2, (0, 1, 2, 3))} # `NAME` can have multiple start states for this input - res = to_python(find_partial_matches(name_fsm, "d")) + res = to_python(find_partial_matches(name_fsm, "d", full_match=False)) assert res == {(0, (0, 1)), (0, (1, 1))} # Not this case res = to_python(find_partial_matches(name_fsm, "1d")) @@ -133,7 +145,7 @@ def to_python(res): float_fsm = float_fsm.fsm_info - res = to_python(find_partial_matches(float_fsm, ".")) + res = to_python(find_partial_matches(float_fsm, ".", full_match=False)) assert res == {(0, (3, 5)), (0, (4, 5)), (0, (0, 2))} joins_fsm, _ = make_deterministic_fsm(