Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove default value checks #852

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/datachain/data_storage/db_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def has_table(self, name: str) -> bool:
return sa.inspect(self.engine).has_table(name)

@abstractmethod
def create_table(self, table: "Table", if_not_exists: bool = True) -> None: ...
def create_table(self, table: "Table", if_not_exists: bool = True) -> "Table": ...

@abstractmethod
def drop_table(self, table: "Table", if_exists: bool = False) -> None: ...
Expand Down
3 changes: 2 additions & 1 deletion src/datachain/data_storage/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,8 +276,9 @@ def has_table(self, name: str) -> bool:
)
return bool(next(self.execute(query))[0])

def create_table(self, table: "Table", if_not_exists: bool = True) -> None:
def create_table(self, table: "Table", if_not_exists: bool = True) -> "Table":
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changing this because it simplifies the code in the Studio part, where table columns are modified to be nullable.

self.execute(CreateTable(table, if_not_exists=if_not_exists))
return table

def drop_table(self, table: "Table", if_exists: bool = False) -> None:
self.execute(DropTable(table, if_exists=if_exists))
Expand Down
35 changes: 14 additions & 21 deletions tests/func/test_datachain_merge.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import pytest

from datachain.lib.dc import DataChain
from datachain.sql.types import Int


@pytest.mark.parametrize(
Expand All @@ -11,16 +10,13 @@
)
@pytest.mark.parametrize("inner", [True, False])
def test_merge_union(cloud_test_catalog, inner, cloud_type):
catalog = cloud_test_catalog.catalog
session = cloud_test_catalog.session

src = cloud_test_catalog.src_uri

dogs = DataChain.from_storage(f"{src}/dogs/*", session=session)
cats = DataChain.from_storage(f"{src}/cats/*", session=session)

signal_default_value = Int.default_value(catalog.warehouse.db.dialect)

dogs1 = dogs.map(sig1=lambda: 1, output={"sig1": int})
dogs2 = dogs.map(sig2=lambda: 2, output={"sig2": int})
cats1 = cats.map(sig1=lambda: 1, output={"sig1": int})
Expand All @@ -37,8 +33,8 @@ def test_merge_union(cloud_test_catalog, inner, cloud_type):
]
else:
assert signals == [
("cats/cat1", 1, signal_default_value),
("cats/cat2", 1, signal_default_value),
("cats/cat1", 1, None),
("cats/cat2", 1, None),
("dogs/dog1", 1, 2),
("dogs/dog2", 1, 2),
("dogs/dog3", 1, 2),
Expand All @@ -55,16 +51,13 @@ def test_merge_union(cloud_test_catalog, inner, cloud_type):
@pytest.mark.parametrize("inner2", [True, False])
@pytest.mark.parametrize("inner3", [True, False])
def test_merge_multiple(cloud_test_catalog, inner1, inner2, inner3):
catalog = cloud_test_catalog.catalog
session = cloud_test_catalog.session

src = cloud_test_catalog.src_uri

dogs = DataChain.from_storage(f"{src}/dogs/*", session=session)
cats = DataChain.from_storage(f"{src}/cats/*", session=session)

signal_default_value = Int.default_value(catalog.warehouse.db.dialect)

dogs_and_cats = dogs | cats
dogs1 = dogs.map(sig1=lambda: 1, output={"sig1": int})
cats1 = cats.map(sig2=lambda: 2, output={"sig2": int})
Expand All @@ -80,22 +73,22 @@ def test_merge_multiple(cloud_test_catalog, inner1, inner2, inner3):
assert merged_signals == []
elif inner1:
assert merged_signals == [
("dogs/dog1", 1, signal_default_value),
("dogs/dog2", 1, signal_default_value),
("dogs/dog3", 1, signal_default_value),
("dogs/others/dog4", 1, signal_default_value),
("dogs/dog1", 1, None),
("dogs/dog2", 1, None),
("dogs/dog3", 1, None),
("dogs/others/dog4", 1, None),
]
elif inner2 and inner3:
assert merged_signals == [
("cats/cat1", signal_default_value, 2),
("cats/cat2", signal_default_value, 2),
("cats/cat1", None, 2),
("cats/cat2", None, 2),
]
else:
assert merged_signals == [
("cats/cat1", signal_default_value, 2),
("cats/cat2", signal_default_value, 2),
("dogs/dog1", 1, signal_default_value),
("dogs/dog2", 1, signal_default_value),
("dogs/dog3", 1, signal_default_value),
("dogs/others/dog4", 1, signal_default_value),
("cats/cat1", None, 2),
("cats/cat2", None, 2),
("dogs/dog1", 1, None),
("dogs/dog2", 1, None),
("dogs/dog3", 1, None),
("dogs/others/dog4", 1, None),
]
15 changes: 6 additions & 9 deletions tests/func/test_dataset_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -743,10 +743,9 @@ def test_join_with_binary_expression(
("dogs/others/dog4", "dogs/others/dog4"),
]
else:
string_default = String.default_value(catalog.warehouse.db.dialect)
expected = [
("cats/cat1", string_default),
("cats/cat2", string_default),
("cats/cat1", None),
("cats/cat2", None),
("dogs/dog1", "dogs/dog1"),
("dogs/dog2", "dogs/dog2"),
("dogs/dog3", "dogs/dog3"),
Expand Down Expand Up @@ -793,10 +792,9 @@ def test_join_with_combination_binary_expression_and_column_predicates(
("dogs/others/dog4", "dogs/others/dog4"),
]
else:
string_default = String.default_value(catalog.warehouse.db.dialect)
expected = [
("cats/cat1", string_default),
("cats/cat2", string_default),
("cats/cat1", None),
("cats/cat2", None),
("dogs/dog1", "dogs/dog1"),
("dogs/dog2", "dogs/dog2"),
("dogs/dog3", "dogs/dog3"),
Expand Down Expand Up @@ -918,10 +916,9 @@ def test_join_with_using_functions_in_expression(
("dogs/others/dog4", "dogs/others/dog4"),
]
else:
string_default = String.default_value(catalog.warehouse.db.dialect)
expected = [
("cats/cat1", string_default),
("cats/cat2", string_default),
("cats/cat1", None),
("cats/cat2", None),
("dogs/dog1", "dogs/dog1"),
("dogs/dog2", "dogs/dog2"),
("dogs/dog3", "dogs/dog3"),
Expand Down
29 changes: 19 additions & 10 deletions tests/unit/lib/test_datachain.py
Original file line number Diff line number Diff line change
Expand Up @@ -1052,17 +1052,13 @@ def test_parse_nested_json(tmp_dir, test_session):
# E.g. nAmE -> name, l--as@t -> l_as_t, etc
df1 = dc.select("na_me", "age", "city").to_pandas()

# In CH we replace None with '' for performance reasons,
# have to handle it here
string_default = String.default_value(test_session.catalog.warehouse.db.dialect)

assert sorted(df1["na_me"]["first_select"].to_list()) == sorted(
d["first-SELECT"] for d in df["nA-mE"].to_list()
)
assert sorted(
df1["na_me"]["l_as_t"].to_list(), key=lambda x: (x is None, x)
) == sorted(
[d.get("l--as@t", string_default) for d in df["nA-mE"].to_list()],
[d.get("l--as@t", None) for d in df["nA-mE"].to_list()],
key=lambda x: (x is None, x),
)

Expand Down Expand Up @@ -1304,6 +1300,7 @@ def test_from_csv_null_collect(tmp_dir, test_session):
for i, row in enumerate(dc.collect()):
# None value in numeric column will get converted to nan.
if not height[i]:
print(row[1].height)
assert math.isnan(row[1].height)
else:
assert row[1].height == height[i]
Expand Down Expand Up @@ -1420,10 +1417,6 @@ def test_explode(tmp_dir, test_session, column_type, object_name, model_name):
object_name = object_name or "content_expl"
model_name = model_name or "ContentExplodedModel"

# In CH we have (atm at least) None converted to ''
# for performance reasons, so we need to handle this case
string_default = String.default_value(test_session.catalog.warehouse.db.dialect)

assert set(
dc.collect(
f"{object_name}.na_me.first_select",
Expand All @@ -1433,7 +1426,7 @@ def test_explode(tmp_dir, test_session, column_type, object_name, model_name):
) == {
("Alice", 25, "New York"),
("Bob", 30, "Los Angeles"),
("Charlie", 35, string_default),
("Charlie", 35, None),
("David", 40, "Houston"),
("Eva", 45, "Phoenix"),
("Ivan", 41, "San Francisco"),
Expand Down Expand Up @@ -2097,6 +2090,22 @@ def test_from_values_array_of_floats(test_session):
assert list(chain.order_by("emd").collect("emd")) == embeddings


def test_from_values_array_of_ints_with_nones(test_session):
    """Lists of ints that contain ``None`` are collected back unchanged.

    Builds a chain from two rows whose ``emd`` values are lists (one of them
    holding a ``None`` element) and checks that collecting ``emd`` in ``ids``
    order yields the same lists that were passed in.
    """
    row_ids = [1, 2]
    vectors = [[1, None], [4, 5]]

    chain = DataChain.from_values(emd=vectors, ids=row_ids, session=test_session)
    collected = list(chain.order_by("ids").collect("emd"))

    assert collected == vectors


def test_from_values_with_nones(test_session):
    """Scalar ``None`` values are collected back as ``None``.

    Builds a chain from four rows where every other ``emd`` value is ``None``
    and checks that collecting ``emd`` in ``ids`` order returns the values
    exactly as they were supplied.
    """
    row_ids = [1, 2, 3, 4]
    values = [100, None, 300, None]

    chain = DataChain.from_values(emd=values, ids=row_ids, session=test_session)
    collected = list(chain.order_by("ids").collect("emd"))

    # The input list is never mutated above, so comparing against it is the
    # same check as comparing against the literal [100, None, 300, None].
    assert collected == values


def test_custom_model_with_nested_lists(test_session):
class Trace(BaseModel):
x: float
Expand Down
18 changes: 6 additions & 12 deletions tests/unit/lib/test_datachain_merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from sqlalchemy import func

from datachain.lib.dc import C, DataChain, DatasetMergeError
from datachain.sql.types import Int, String
from tests.utils import skip_if_not_sqlite


Expand Down Expand Up @@ -52,8 +51,6 @@ def test_merge_objects(test_session):
ch2 = DataChain.from_values(team=team, session=test_session)
ch = ch1.merge(ch2, "emp.person.name", "team.player")

str_default = String.default_value(test_session.catalog.warehouse.db.dialect)

i = 0
j = 0
for items in ch.order_by("emp.person.name", "team.player").collect():
Expand All @@ -72,8 +69,8 @@ def test_merge_objects(test_session):
assert math.isclose(player.height, team[j].height, rel_tol=1e-7)
j += 1
else:
assert player.player == str_default
assert player.sport == str_default
assert player.player is None
assert player.sport is None
assert pd.isnull(player.weight)
assert pd.isnull(player.height)

Expand All @@ -95,9 +92,6 @@ def test_merge_objects_full_join(test_session, multiple_predicates):
else:
ch = ch1.merge(ch2, "emp.person.name", "team.player", full=True)

str_default = String.default_value(test_session.catalog.warehouse.db.dialect)
int_default = Int.default_value(test_session.catalog.warehouse.db.dialect)

i = 0
for items in ch.order_by("emp.person.name", "team.player").collect():
assert len(items) == 2
Expand All @@ -107,13 +101,13 @@ def test_merge_objects_full_join(test_session, multiple_predicates):
assert isinstance(player, TeamMember)

if player.player == "John":
assert empl.person.name == str_default
assert empl.person.age == int_default
assert empl.person.name is None
assert empl.person.age is None
continue

if empl.person.name == "Bob":
assert player.player == str_default
assert player.sport == str_default
assert player.player is None
assert player.sport is None
assert pd.isnull(player.weight)
assert pd.isnull(player.height)
continue
Expand Down
13 changes: 3 additions & 10 deletions tests/unit/lib/test_diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from datachain.diff import CompareStatus, compare_and_split
from datachain.lib.dc import DataChain
from datachain.lib.file import File
from datachain.sql.types import Int64, String
from tests.utils import sorted_dicts


Expand Down Expand Up @@ -163,15 +162,13 @@ def test_compare_with_explicit_compare_fields(test_session, right_name):
status_col="diff",
)

string_default = String.default_value(test_session.catalog.warehouse.db.dialect)

expected = [
(CompareStatus.MODIFIED, 1, "John1", "New York"),
(CompareStatus.ADDED, 2, "Doe", "Boston"),
(
CompareStatus.DELETED,
3,
string_default if right_name == "other_name" else "Mark",
None if right_name == "other_name" else "Mark",
"Seattle",
),
(CompareStatus.SAME, 4, "Andy", "San Francisco"),
Expand Down Expand Up @@ -202,13 +199,11 @@ def test_compare_different_left_right_on_columns(test_session):
status_col="diff",
)

int_default = Int64.default_value(test_session.catalog.warehouse.db.dialect)

expected = [
(CompareStatus.SAME, 4, "Andy"),
(CompareStatus.ADDED, 2, "Doe"),
(CompareStatus.MODIFIED, 1, "John1"),
(CompareStatus.DELETED, int_default, "Mark"),
(CompareStatus.DELETED, None, "Mark"),
]

collect_fields = ["diff", "id", "name"]
Expand Down Expand Up @@ -316,8 +311,6 @@ def test_compare_additional_column_on_left(test_session):
session=test_session,
).save("ds2")

string_default = String.default_value(test_session.catalog.warehouse.db.dialect)

diff = ds1.compare(ds2, same=True, on=["id"], status_col="diff")

assert sorted_dicts(diff.to_records(), "id") == sorted_dicts(
Expand All @@ -328,7 +321,7 @@ def test_compare_additional_column_on_left(test_session):
"diff": CompareStatus.DELETED,
"id": 3,
"name": "Mark",
"city": string_default,
"city": None,
},
{"diff": CompareStatus.MODIFIED, "id": 4, "name": "Andy", "city": "Tokyo"},
],
Expand Down