From 31620b5a0648dde330977c84e07cdf27c5021fdc Mon Sep 17 00:00:00 2001 From: wd0517 Date: Wed, 8 May 2024 14:12:47 +0800 Subject: [PATCH] fix: handle empty vector in decode_vector --- tests/peewee/test_peewee.py | 7 +++++++ tests/sqlalchemy/test_sqlalchemy.py | 10 ++++++++++ tidb_vector/utils.py | 3 +++ 3 files changed, 20 insertions(+) diff --git a/tests/peewee/test_peewee.py b/tests/peewee/test_peewee.py index 2dd081e..866611a 100644 --- a/tests/peewee/test_peewee.py +++ b/tests/peewee/test_peewee.py @@ -48,6 +48,13 @@ def teardown_class(self): def setup_method(self): Item1Model.truncate_table() + def test_empty_vector(self): + Item1Model.create(embedding=[]) + assert Item1Model.select().count() == 1 + item1 = Item1Model.get() + assert np.array_equal(item1.embedding, np.array([])) + assert item1.embedding.dtype == np.float32 + def test_insert_get_record(self): Item1Model.create(embedding=[1, 2, 3]) assert Item1Model.select().count() == 1 diff --git a/tests/sqlalchemy/test_sqlalchemy.py b/tests/sqlalchemy/test_sqlalchemy.py index c3d3b79..95e5630 100644 --- a/tests/sqlalchemy/test_sqlalchemy.py +++ b/tests/sqlalchemy/test_sqlalchemy.py @@ -56,6 +56,16 @@ def test_insert_get_record(self): assert np.array_equal(item1.embedding, np.array([1, 2, 3])) assert item1.embedding.dtype == np.float32 + def test_empty_vector(self): + with Session() as session: + item1 = Item1Model(embedding=[]) + session.add(item1) + session.commit() + assert session.query(Item1Model).count() == 1 + item1 = session.query(Item1Model).first() + assert np.array_equal(item1.embedding, np.array([])) + assert item1.embedding.dtype == np.float32 + def test_get_with_different_dimensions(self): with Session() as session: item1 = Item1Model(embedding=[1, 2, 3]) diff --git a/tidb_vector/utils.py b/tidb_vector/utils.py index f34d2db..c629682 100644 --- a/tidb_vector/utils.py +++ b/tidb_vector/utils.py @@ -29,4 +29,7 @@ def decode_vector(value): if isinstance(value, bytes): value = value.decode("utf-8") + if value == "[]": + return np.array([], dtype=np.float32) + return np.array(value[1:-1].split(","), dtype=np.float32)