Add sampling to our ParquetSource #773

Merged · 5 commits · Oct 20, 2023
Changes from 3 commits
4 changes: 4 additions & 0 deletions .env
@@ -19,6 +19,10 @@ DUCKDB_USE_VIEWS=0
# GCS_REGION=
# GCS_ACCESS_KEY=
# GCS_SECRET_KEY=
# S3_REGION=
# S3_ENDPOINT=
# S3_ACCESS_KEY=
# S3_SECRET_KEY=

# Get key from https://platform.openai.com/account/api-keys
# OPENAI_API_KEY=
3 changes: 1 addition & 2 deletions lilac/data/dataset_compute_signal_chain_test.py
@@ -8,8 +8,6 @@
from pytest_mock import MockerFixture
from typing_extensions import override

from lilac.sources.source_registry import clear_source_registry, register_source

from ..embeddings.vector_store import VectorDBIndex
from ..schema import (
EMBEDDING_KEY,
@@ -31,6 +29,7 @@
clear_signal_registry,
register_signal,
)
from ..sources.source_registry import clear_source_registry, register_source
from .dataset import DatasetManifest
from .dataset_test_utils import (
TEST_DATASET_NAME,
3 changes: 1 addition & 2 deletions lilac/data/dataset_compute_signal_test.py
@@ -7,8 +7,6 @@
from pytest_mock import MockerFixture
from typing_extensions import override

from lilac.sources.source_registry import clear_source_registry, register_source

from ..concepts.concept import ExampleIn
from ..concepts.db_concept import ConceptUpdate, DiskConceptDB
from ..schema import (
@@ -30,6 +28,7 @@
register_signal,
)
from ..signals.concept_scorer import ConceptSignal
from ..sources.source_registry import clear_source_registry, register_source
from . import dataset_utils as dataset_utils_module
from .dataset import Column, DatasetManifest, GroupsSortBy, SortOrder
from .dataset_test_utils import (
3 changes: 1 addition & 2 deletions lilac/data/dataset_labels_test.py
@@ -7,9 +7,8 @@
from freezegun import freeze_time
from pytest_mock import MockerFixture

from lilac.sources.source_registry import clear_source_registry, register_source

from ..schema import PATH_WILDCARD, ROWID, Item, field, schema
from ..sources.source_registry import clear_source_registry, register_source
from .dataset import DatasetManifest, SelectGroupsResult, SortOrder
from .dataset_test_utils import TestDataMaker, TestSource

3 changes: 1 addition & 2 deletions lilac/data/dataset_map_test.py
@@ -8,10 +8,9 @@
from freezegun import freeze_time
from typing_extensions import override

from lilac.sources.source_registry import clear_source_registry, register_source

from ..schema import PATH_WILDCARD, VALUE_KEY, Field, Item, MapInfo, RichData, field, schema
from ..signal import TextSignal, clear_signal_registry, register_signal
from ..sources.source_registry import clear_source_registry, register_source
from .dataset import DatasetManifest
from .dataset_test_utils import (
TEST_DATASET_NAME,
7 changes: 3 additions & 4 deletions lilac/data/dataset_select_rows_udf_test.py
@@ -7,10 +7,8 @@
from pytest import approx
from typing_extensions import override

from lilac.concepts.concept import ExampleIn
from lilac.concepts.db_concept import ConceptUpdate, DiskConceptDB
from lilac.signals.concept_scorer import ConceptSignal

from ..concepts.concept import ExampleIn
from ..concepts.db_concept import ConceptUpdate, DiskConceptDB
from ..embeddings.vector_store import VectorDBIndex
from ..schema import (
ROWID,
@@ -32,6 +30,7 @@
clear_signal_registry,
register_signal,
)
from ..signals.concept_scorer import ConceptSignal
from .dataset import BinaryFilterTuple, Column, SortOrder
from .dataset_test_utils import TestDataMaker, enriched_item

3 changes: 1 addition & 2 deletions lilac/data/dataset_test.py
@@ -6,11 +6,10 @@
import pytest
from typing_extensions import override

from lilac.sources.source_registry import clear_source_registry, register_source

from ..config import DatasetConfig, EmbeddingConfig, SignalConfig
from ..schema import EMBEDDING_KEY, ROWID, Field, Item, RichData, field, lilac_embedding, schema
from ..signal import TextEmbeddingSignal, TextSignal, clear_signal_registry, register_signal
from ..sources.source_registry import clear_source_registry, register_source
from .dataset import Column, DatasetManifest, dataset_config_from_manifest
from .dataset_test_utils import (
TEST_DATASET_NAME,
3 changes: 1 addition & 2 deletions lilac/load_dataset_test.py
@@ -9,8 +9,6 @@
from pytest_mock import MockerFixture
from typing_extensions import override

from lilac.sources.source_registry import clear_source_registry, register_source

from .config import Config, DatasetConfig, DatasetSettings, DatasetUISettings
from .data.dataset import SourceManifest
from .data.dataset_duckdb import read_source_manifest
@@ -19,6 +17,7 @@
from .project import read_project_config
from .schema import PARQUET_FILENAME_PREFIX, ROWID, Item, schema
from .source import Source, SourceSchema
from .sources.source_registry import clear_source_registry, register_source
from .test_utils import fake_uuid, read_items
from .utils import DATASETS_DIR_NAME

2 changes: 1 addition & 1 deletion lilac/schema.py
@@ -20,7 +20,7 @@
)
from typing_extensions import TypedDict

from lilac.utils import is_primitive, log
from .utils import is_primitive, log

MANIFEST_FILENAME = 'manifest.json'
PARQUET_FILENAME_PREFIX = 'data'
5 changes: 2 additions & 3 deletions lilac/signals/cluster_dbscan.py
@@ -6,12 +6,11 @@
from sklearn.cluster import DBSCAN
from typing_extensions import override

from lilac.embeddings.vector_store import VectorDBIndex
from lilac.utils import DebugTimer

from ..embeddings.embedding import get_embed_fn
from ..embeddings.vector_store import VectorDBIndex
from ..schema import Field, Item, PathKey, RichData, SignalInputType, SpanVector, field, lilac_span
from ..signal import VectorSignal
from ..utils import DebugTimer

CLUSTER_ID = 'cluster_id'
MIN_SAMPLES = 5
5 changes: 2 additions & 3 deletions lilac/signals/cluster_hdbscan.py
@@ -7,12 +7,11 @@
from sklearn.cluster import HDBSCAN
from typing_extensions import override

from lilac.embeddings.vector_store import VectorDBIndex
from lilac.utils import DebugTimer

from ..embeddings.embedding import get_embed_fn
from ..embeddings.vector_store import VectorDBIndex
from ..schema import Field, Item, PathKey, RichData, SignalInputType, SpanVector, field, lilac_span
from ..signal import VectorSignal
from ..utils import DebugTimer

CLUSTER_ID = 'cluster_id'
MIN_CLUSTER_SIZE = 5
2 changes: 1 addition & 1 deletion lilac/sources/csv_source.py
@@ -42,14 +42,14 @@ def setup(self) -> None:
filepaths = download_http_files(self.filepaths)

self._con = duckdb.connect(database=':memory:')
duckdb_setup(self._con)

# DuckDB expects s3 protocol: https://duckdb.org/docs/guides/import/s3_import.html.
s3_filepaths = [path.replace('gs://', 's3://') for path in filepaths]

# NOTE: We use duckdb here to increase parallelism for multiple files.
# NOTE: We turn off the parallel reader because of https://github.com/lilacai/lilac/issues/373.
self._con.execute(f"""
{duckdb_setup(self._con)}
CREATE SEQUENCE serial START 1;
CREATE VIEW t as (SELECT nextval('serial') as "{LINE_NUMBER_COLUMN}", * FROM read_csv_auto(
{s3_filepaths},
26 changes: 17 additions & 9 deletions lilac/sources/duckdb_utils.py
@@ -6,17 +6,25 @@
from ..env import env, get_project_dir


def duckdb_setup(con: duckdb.DuckDBPyConnection) -> str:
def duckdb_setup(con: duckdb.DuckDBPyConnection) -> None:
"""Setup DuckDB. This includes setting up the extensions directory and GCS access."""
con.execute(f"""
SET extension_directory='{os.path.join(get_project_dir(), '.duckdb')}';
""")

if env('GCS_REGION'):
return f"""
SET s3_region='{env('GCS_REGION')}';
SET s3_access_key_id='{env('GCS_ACCESS_KEY')}';
SET s3_secret_access_key='{env('GCS_SECRET_KEY')}';
SET s3_endpoint='storage.googleapis.com';
"""
return ''
region = env('GCS_REGION') or env('S3_REGION')
if region:
con.execute(f"SET s3_region='{region}'")

access_key = env('GCS_ACCESS_KEY') or env('S3_ACCESS_KEY')
if access_key:
con.execute(f"SET s3_access_key_id='{access_key}'")

secret_key = env('GCS_SECRET_KEY') or env('S3_SECRET_KEY')
if secret_key:
con.execute(f"SET s3_secret_access_key='{secret_key}'")

gcs_endpoint = 'storage.googleapis.com'
endpoint = env('S3_ENDPOINT') or (gcs_endpoint if env('GCS_REGION') else None)
if endpoint:
con.execute(f"SET s3_endpoint='{endpoint}'")
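
For illustration, a rough sketch of how a source is now expected to use this helper — the connection is configured up front instead of splicing a returned SQL snippet into each query (the bucket path below is a placeholder):

import duckdb

from lilac.sources.duckdb_utils import duckdb_setup

con = duckdb.connect(database=':memory:')
duckdb_setup(con)  # applies the extension directory plus any GCS/S3 settings found in env vars
# Placeholder path; DuckDB reads it using the credentials configured above.
con.execute("CREATE VIEW t AS (SELECT * FROM read_parquet(['s3://example-bucket/data.parquet']))")
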
5 changes: 2 additions & 3 deletions lilac/sources/json_source.py
@@ -36,15 +36,14 @@ class JSONSource(Source):
def setup(self) -> None:
# Download JSON files to local cache if they are via HTTP to speed up duckdb.
filepaths = download_http_files(self.filepaths)

self._con = duckdb.connect(database=':memory:')
duckdb_setup(self._con)

# DuckDB expects s3 protocol: https://duckdb.org/docs/guides/import/s3_import.html.
s3_filepaths = [path.replace('gs://', 's3://') for path in filepaths]

# NOTE: We use duckdb here to increase parallelism for multiple files.
self._con.execute(f"""
{duckdb_setup(self._con)}
CREATE VIEW t as (SELECT * FROM read_json_auto(
{s3_filepaths},
IGNORE_ERRORS=true
@@ -62,7 +61,7 @@ def setup(self) -> None:
@override
def source_schema(self) -> SourceSchema:
Contributor:

not for now but maybe we should make schemas optional and let duckdb infer types to reduce cognitive overhead of both sources and signals

then signals are very close to a map

Collaborator Author:

it's our pq.ParquetWriter that needs a schema ahead of time to set up a writer, before writing a single row to disk. And that schema needs to be consistent with 100% of the rows going into that writer to avoid write errors. That means we'd need to see the entire data in order to correctly infer the schema, if it's not provided by the user. Or we circumvent our writer and get duckdb to read the format and dump to parquet directly.

Collaborator Author (@dsmilkov, Oct 20, 2023):

note that pq.ParquetWriter also doesn't hold everything in memory; it dumps to parquet every 128MB (row_group_buffer_size), with 10k items per row group.
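
For illustration, a small sketch of the constraint described above, with made-up file and column names: pyarrow's ParquetWriter takes the schema at construction time, and every batch written afterwards must match it.

import pyarrow as pa
import pyarrow.parquet as pq

schema = pa.schema([('name', pa.string()), ('age', pa.int64())])
writer = pq.ParquetWriter('out.parquet', schema)  # the schema is required before any row is written
batch = pa.record_batch([pa.array(['a', 'b']), pa.array([1, 2])], schema=schema)
writer.write_batch(batch)  # a batch with a mismatched schema would raise here
writer.close()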

"""Return the source schema."""
assert self._source_schema is not None
assert self._source_schema is not None, 'setup() must be called first.'
return self._source_schema

@override
62 changes: 50 additions & 12 deletions lilac/sources/parquet_source.py
@@ -1,13 +1,15 @@
"""Parquet source."""
from typing import ClassVar, Iterable, Optional
from typing import ClassVar, Iterable, Optional, cast

import duckdb
import pyarrow as pa
import pyarrow.parquet as pq
from pydantic import Field
from pydantic import Field, field_validator
from typing_extensions import override

from ..schema import Item, arrow_schema_to_schema
from ..source import Source, SourceSchema
from ..sources.duckdb_utils import duckdb_setup
from ..utils import download_http_files


class ParquetSource(Source):
@@ -22,17 +24,49 @@ class ParquetSource(Source):
filepaths: list[str] = Field(
description=
'A list of paths to parquet files which live locally or remotely on GCS, S3, or Hadoop.')
sample_size: Optional[int] = Field(
title='Sample size', description='Number of rows to sample from the dataset', default=None)

_source_schema: Optional[SourceSchema] = None
_table: Optional[pa.Table] = None
_reader: Optional[pa.RecordBatchReader] = None
_con: Optional[duckdb.DuckDBPyConnection] = None

@field_validator('filepaths')
@classmethod
def validate_filepaths(cls, filepaths: list[str]) -> list[str]:
"""Validate filepaths."""
if not filepaths:
raise ValueError('filepaths must be non-empty.')
return filepaths

@field_validator('sample_size')
@classmethod
def validate_sample_size(cls, sample_size: int) -> int:
"""Validate sample size."""
if sample_size < 1:
raise ValueError('sample_size must be greater than 0.')
return sample_size

@override
def setup(self) -> None:
assert self.filepaths, 'filepaths must be specified.'
self._table = pa.concat_tables([pq.read_table(f) for f in self.filepaths])
self._source_schema = SourceSchema(
fields=arrow_schema_to_schema(pq.read_schema(self.filepaths[0])).fields,
num_items=self._table.num_rows)
filepaths = download_http_files(self.filepaths)
self._con = duckdb.connect(database=':memory:')
duckdb_setup(self._con)

# DuckDB expects s3 protocol: https://duckdb.org/docs/guides/import/s3_import.html.
s3_filepaths = [path.replace('gs://', 's3://') for path in filepaths]

# NOTE: We use duckdb here to increase parallelism for multiple files.
sample_suffix = f'USING SAMPLE {self.sample_size}' if self.sample_size else ''
self._con.execute(f"""
CREATE VIEW t as (SELECT * FROM read_parquet({s3_filepaths}) {sample_suffix});
""")
res = self._con.execute('SELECT COUNT(*) FROM t').fetchone()
num_items = cast(tuple[int], res)[0]
self._reader = self._con.execute('SELECT * from t').fetch_record_batch(rows_per_batch=10_000)
Contributor:

these computations seem to belong in process(), not in setup()

Collaborator Author (@dsmilkov, Oct 20, 2023):

great q.

self._reader = self._con.execute('SELECT * from t').fetch_record_batch(rows_per_batch=10_000)

returns a lazy iterator, so no data is being read yet, but we found that executing this in setup() catches a lot of "setup" bugs: file not found, unrecognized parquet format (broken header), unauthorized S3/GCS bucket reads, etc.

In addition, once you have a reader, you can read the inferred schema before reading the data, and our sources need the schema before process() so they can set up a parquet writer with a buffer ahead of time.
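
For illustration, a rough sketch of that behavior against a throwaway in-memory table: the reader's schema is available immediately, while batches are only pulled when iterated.

import duckdb

con = duckdb.connect(database=':memory:')
con.execute('CREATE VIEW t AS (SELECT * FROM range(100000))')
reader = con.execute('SELECT * FROM t').fetch_record_batch(rows_per_batch=10_000)
print(reader.schema)  # the schema is known here, before any rows are materialized
for batch in reader:  # batches of up to 10k rows are pulled lazily
  print(batch.num_rows)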

# Create the source schema in setup() to share it between process() and source_schema().
schema = arrow_schema_to_schema(self._reader.schema)
self._source_schema = SourceSchema(fields=schema.fields, num_items=num_items)

@override
def source_schema(self) -> SourceSchema:
@@ -43,6 +77,10 @@ def source_schema(self) -> SourceSchema:
@override
def process(self) -> Iterable[Item]:
"""Process the source."""
assert self._table is not None, 'setup() must be called first.'
for row in self._table.to_pylist():
yield row
assert self._reader and self._con, 'setup() must be called first.'

for batch in self._reader:
yield from batch.to_pylist()

self._reader.close()
self._con.close()
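
For reference, a hedged usage sketch of the new option (the GCS path is a placeholder); sampling is delegated to DuckDB's USING SAMPLE clause, so at most sample_size rows are ever materialized:

from lilac.sources.parquet_source import ParquetSource

source = ParquetSource(filepaths=['gs://example-bucket/data.parquet'], sample_size=10_000)
source.setup()  # builds the sampled DuckDB view and infers the source schema
print(source.source_schema().num_items)  # row count of the sampled view, at most 10_000
rows = list(source.process())  # yields plain dicts, one per sampled row
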
32 changes: 32 additions & 0 deletions lilac/sources/parquet_source_test.py
@@ -5,6 +5,8 @@

import pyarrow as pa
import pyarrow.parquet as pq
import pytest
from pydantic import ValidationError

from ..schema import schema
from ..source import SourceSchema
@@ -37,3 +39,33 @@ def test_simple_rows(tmp_path: pathlib.Path) -> None:

items = list(source.process())
assert items == [{'name': 'a', 'age': 1}, {'name': 'b', 'age': 2}, {'name': 'c', 'age': 3}]


def test_sampling(tmp_path: pathlib.Path) -> None:
table = pa.Table.from_pylist([{
'name': 'a',
'age': 1
}, {
'name': 'b',
'age': 2
}, {
'name': 'c',
'age': 3
}])

out_file = os.path.join(tmp_path, 'test.parquet')
pq.write_table(table, out_file)

for sample_size in range(1, 4):
source = ParquetSource(filepaths=[out_file], sample_size=sample_size)
source.setup()
items = list(source.process())
assert len(items) == sample_size


def test_validation() -> None:
with pytest.raises(ValidationError, match='filepaths must be non-empty'):
ParquetSource(filepaths=[])

with pytest.raises(ValidationError, match='sample_size must be greater than 0'):
ParquetSource(filepaths=['gs://lilac/test.parquet'], sample_size=0)
5 changes: 2 additions & 3 deletions lilac/sources/sqlite_source.py
@@ -10,10 +10,9 @@
from pydantic import Field
from typing_extensions import override

from lilac.utils import file_exists

from ..schema import Item, arrow_schema_to_schema
from ..source import Source, SourceSchema
from ..utils import file_exists
from .duckdb_utils import duckdb_setup

router = APIRouter()
@@ -48,12 +47,12 @@ class SQLiteSource(Source):
@override
def setup(self) -> None:
self._con = duckdb.connect(database=':memory:')
duckdb_setup(self._con)

# DuckDB expects s3 protocol: https://duckdb.org/docs/guides/import/s3_import.html.
db_file = self.db_file.replace('gs://', 's3://')

self._con.execute(f"""
{duckdb_setup(self._con)}
CREATE VIEW t as (SELECT * FROM sqlite_scan('{db_file}', '{self.table}'));
""")
