Add intelligent sampling in ParquetSource #778

Merged · 14 commits · Oct 23, 2023
Changes from 7 commits
107 changes: 86 additions & 21 deletions docs/datasets/dataset_load.md
@@ -63,6 +63,92 @@ You will be redirected to the dataset view once your data is loaded.

## From Python

### Creating a dataset

You can create a dataset from Python using [](#lilac.create_dataset). Lilac supports a variety of
data sources, including CSV, JSON, HuggingFace datasets, Parquet, Pandas and more. See
[](#lilac.sources) for details on available sources.

Before we load any dataset, we should set the project directory, which will be used to store all
the datasets we import. If not set, it defaults to the current working directory.

```python
import lilac as ll
ll.set_project_dir('~/my_project')
```

#### HuggingFace

This example loads the `glue` dataset with the `ax` config from HuggingFace:

```python
config = ll.DatasetConfig(
namespace='local',
name='glue',
source=ll.HuggingFaceSource(dataset_name='glue', config_name='ax'))
# NOTE: You can pass a `project_dir` to `create_dataset` as the second argument.
dataset = ll.create_dataset(config)
```

#### CSV

```python
url = 'https://storage.googleapis.com/lilac-data/datasets/the_movies_dataset/the_movies_dataset.csv'
config = ll.DatasetConfig(
namespace='local', name='the_movies_dataset', source=ll.CSVSource(filepaths=[url]))
dataset = ll.create_dataset(config)
```

#### Parquet

The parquet reader can read from local files, S3 or GCS. If your dataset is sharded, you can use a
glob pattern to load multiple files.

**Sampling**

The `sample_size` and `shuffle_before_sampling` arguments are optional. When
`shuffle_before_sampling` is `True`, the reader will shuffle the entire dataset before sampling, but
this requires fetching the entire dataset. If your dataset is massive and you only want to load the
first `sample_size` rows, set `shuffle_before_sampling` to `False`. When you have many shards and
`shuffle_before_sampling` is `False`, the reader will try to sample a few rows from each shard, to
avoid any shard skew.

```python
source = ll.ParquetSource(
filepaths=['s3://lilac-public-data/test-*.parquet'],
sample_size=100,
shuffle_before_sampling=False)
config = ll.DatasetConfig(namespace='local', name='parquet-test', source=source)
dataset = ll.create_dataset(config)
```
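
If you can afford a full pass over the data, the shuffled variant looks like this (a sketch; the
dataset name `parquet-test-shuffled` is just illustrative):

```python
# Shuffles the entire dataset before taking `sample_size` rows, which
# requires fetching all shards.
source = ll.ParquetSource(
  filepaths=['s3://lilac-public-data/test-*.parquet'],
  sample_size=100,
  shuffle_before_sampling=True)
config = ll.DatasetConfig(namespace='local', name='parquet-test-shuffled', source=source)
dataset = ll.create_dataset(config)
```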

#### JSON

The JSON reader can read from local files, S3 or GCS. If your dataset is sharded, you can use a glob
pattern to load multiple files. The reader supports both JSON and JSONL files.

```python
config = ll.DatasetConfig(
namespace='local',
name='news_headlines',
source=ll.JSONSource(filepaths=[
'https://raw.githubusercontent.com/explosion/prodigy-recipes/master/example-datasets/news_headlines.jsonl'
]))
dataset = ll.create_dataset(config)
```

#### Pandas

```python
import pandas as pd

url = 'https://storage.googleapis.com/lilac-data-us-east1/datasets/csv_datasets/the_movies_dataset/the_movies_dataset.csv'
df = pd.read_csv(url, low_memory=False)
config = ll.DatasetConfig(namespace='local', name='the_movies_dataset2', source=ll.PandasSource(df))
dataset = ll.create_dataset(config)
```

For details on all the source loaders, see [](#lilac.sources). For details on the dataset config,
see [](#lilac.DatasetConfig).

### Loading from lilac.yml

When you start a webserver, Lilac will automatically create a project for you in the given project
@@ -99,24 +185,3 @@ Or from the CLI:
```sh
lilac load --project_dir=~/my_lilac
```

### Loading an individual dataset

This example loads the `glue` dataset with the `ax` config from HuggingFace:

```python
# Set the global project directory to where project files will be stored.
ll.set_project_dir('~/my_project')

config = ll.DatasetConfig(
namespace='local',
name='glue',
source=ll.HuggingFaceSource(dataset_name='glue', config_name='ax'))

# NOTE: If you don't want to set a global project directory, you can pass the `project_dir` to `create_dataset` as the second argument.
dataset = ll.create_dataset(config)
```

For details on all the source loaders, see [](#lilac.sources).

For details on the dataset config, see [](#lilac.DatasetConfig).
4 changes: 0 additions & 4 deletions lilac.yml

This file was deleted.

6 changes: 3 additions & 3 deletions lilac/sources/csv_source.py
@@ -9,7 +9,7 @@
from ..schema import Item, arrow_schema_to_schema
from ..source import Source, SourceSchema
from ..utils import download_http_files
from .duckdb_utils import duckdb_setup
from .duckdb_utils import convert_path_to_duckdb, duckdb_setup

LINE_NUMBER_COLUMN = '__line_number__'

@@ -45,14 +45,14 @@ def setup(self) -> None:
duckdb_setup(self._con)

# DuckDB expects s3 protocol: https://duckdb.org/docs/guides/import/s3_import.html.
s3_filepaths = [path.replace('gs://', 's3://') for path in filepaths]
duckdb_paths = [convert_path_to_duckdb(path) for path in filepaths]

# NOTE: We use duckdb here to increase parallelism for multiple files.
# NOTE: We turn off the parallel reader because of https://github.com/lilacai/lilac/issues/373.
self._con.execute(f"""
CREATE SEQUENCE serial START 1;
CREATE VIEW t as (SELECT nextval('serial') as "{LINE_NUMBER_COLUMN}", * FROM read_csv_auto(
{s3_filepaths},
{duckdb_paths},
SAMPLE_SIZE=500000,
HEADER={self.header},
{f'NAMES={self.names},' if self.names else ''}
56 changes: 36 additions & 20 deletions lilac/sources/duckdb_utils.py
@@ -1,30 +1,46 @@
"""Utils for duckdb."""
import os

import urllib.parse

import duckdb

from ..env import env, get_project_dir
from ..env import env


def duckdb_setup(con: duckdb.DuckDBPyConnection) -> None:
"""Setup DuckDB. This includes setting up the extensions directory and GCS access."""
con.execute(f"""
SET extension_directory='{os.path.join(get_project_dir(), '.duckdb')}';
"""Setup DuckDB. This includes setting up performance optimizations."""
con.execute("""
SET enable_http_metadata_cache=true;
SET enable_object_cache=true;
""")

region = env('GCS_REGION') or env('S3_REGION')
if region:
con.execute(f"SET s3_region='{region}")

access_key = env('GCS_ACCESS_KEY') or env('S3_ACCESS_KEY')
if access_key:
con.execute(f"SET s3_access_key_id='{access_key}")

secret_key = env('GCS_SECRET_KEY') or env('S3_SECRET_KEY')
if secret_key:
con.execute(f"SET s3_secret_access_key='{secret_key}'")

gcs_endpoint = 'storage.googleapis.com'
endpoint = env('S3_ENDPOINT') or (gcs_endpoint if env('GCS_REGION') else None)
if endpoint:
con.execute(f"SET s3_endpoint='{endpoint}'")
def convert_path_to_duckdb(filepath: str) -> str:
"""Convert a filepath to a duckdb filepath."""
scheme = urllib.parse.urlparse(filepath).scheme
options: dict[str, str] = {}
if scheme == '':
return filepath
elif scheme == 'gs':
options['s3_endpoint'] = 'storage.googleapis.com'
if env('GCS_REGION'):
options['s3_region'] = env('GCS_REGION')
if env('GCS_ACCESS_KEY'):
options['s3_access_key_id'] = env('GCS_ACCESS_KEY')
if env('GCS_SECRET_KEY'):
options['s3_secret_access_key'] = env('GCS_SECRET_KEY')
filepath = filepath.replace('gs://', 's3://')
elif scheme == 's3':
if env('S3_ENDPOINT'):
options['s3_endpoint'] = env('S3_ENDPOINT')
if env('S3_REGION'):
options['s3_region'] = env('S3_REGION')
if env('S3_ACCESS_KEY'):
options['s3_access_key_id'] = env('S3_ACCESS_KEY')
if env('S3_SECRET_KEY'):
options['s3_secret_access_key'] = env('S3_SECRET_KEY')
else:
raise ValueError(f'Unsupported scheme: {scheme}')
if options:
return f'{filepath}?{urllib.parse.urlencode(options, safe="+/")}'
return filepath
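
For reference, a minimal sketch of what `convert_path_to_duckdb` returns for the supported schemes,
assuming none of the `GCS_*`/`S3_*` environment variables are set (bucket and file names are
illustrative):

```python
from lilac.sources.duckdb_utils import convert_path_to_duckdb

# Local paths pass through unchanged.
assert convert_path_to_duckdb('data/test.parquet') == 'data/test.parquet'

# gs:// paths are rewritten to s3:// and pointed at the GCS endpoint via querystring
# options; with GCS_* variables set, region and keys are appended as well.
assert (convert_path_to_duckdb('gs://my-bucket/test.parquet') ==
        's3://my-bucket/test.parquet?s3_endpoint=storage.googleapis.com')

# Unsupported schemes are rejected.
try:
  convert_path_to_duckdb('ftp://host/test.parquet')
except ValueError as err:
  print(err)  # Unsupported scheme: ftp
```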
9 changes: 3 additions & 6 deletions lilac/sources/json_source.py
@@ -9,7 +9,7 @@
from ..schema import Item, arrow_schema_to_schema
from ..source import Source, SourceSchema
from ..utils import download_http_files
from .duckdb_utils import duckdb_setup
from .duckdb_utils import convert_path_to_duckdb, duckdb_setup


class JSONSource(Source):
@@ -40,14 +40,11 @@ def setup(self) -> None:
duckdb_setup(self._con)

# DuckDB expects s3 protocol: https://duckdb.org/docs/guides/import/s3_import.html.
s3_filepaths = [path.replace('gs://', 's3://') for path in filepaths]
duckdb_paths = [convert_path_to_duckdb(path) for path in filepaths]

# NOTE: We use duckdb here to increase parallelism for multiple files.
self._con.execute(f"""
CREATE VIEW t as (SELECT * FROM read_json_auto(
{s3_filepaths},
IGNORE_ERRORS=true
));
CREATE VIEW t as (SELECT * FROM read_json_auto({duckdb_paths}, IGNORE_ERRORS=true));
""")

res = self._con.execute('SELECT COUNT(*) FROM t').fetchone()
93 changes: 70 additions & 23 deletions lilac/sources/parquet_source.py
@@ -1,18 +1,19 @@
"""Parquet source."""
import random
from typing import ClassVar, Iterable, Optional, cast

import duckdb
import pyarrow as pa
from pydantic import Field, field_validator
from pydantic import Field, ValidationInfo, field_validator
from typing_extensions import override

from ..schema import Item, arrow_schema_to_schema
from ..schema import Item, Schema, arrow_schema_to_schema
from ..source import Source, SourceSchema
from ..sources.duckdb_utils import duckdb_setup
from ..sources.duckdb_utils import convert_path_to_duckdb, duckdb_setup
from ..utils import download_http_files

# Number of rows to read per batch.
ROWS_PER_BATCH_READ = 10_000
ROWS_PER_BATCH_READ = 50_000


class ParquetSource(Source):
@@ -29,9 +30,13 @@ class ParquetSource(Source):
'A list of paths to parquet files which live locally or remotely on GCS, S3, or Hadoop.')
sample_size: Optional[int] = Field(
title='Sample size', description='Number of rows to sample from the dataset', default=None)
shuffle_before_sampling: bool = Field(
default=False,
description=
'If true, the dataset will be shuffled before sampling, requiring a pass over the entire data.')

_source_schema: Optional[SourceSchema] = None
_reader: Optional[pa.RecordBatchReader] = None
_readers: list[pa.RecordBatchReader] = []
_con: Optional[duckdb.DuckDBPyConnection] = None

@field_validator('filepaths')
@@ -50,26 +55,48 @@ def validate_sample_size(cls, sample_size: int) -> int:
raise ValueError('sample_size must be greater than 0.')
return sample_size

@field_validator('shuffle_before_sampling')
@classmethod
def validate_shuffle_before_sampling(cls, shuffle_before_sampling: bool,
info: ValidationInfo) -> bool:
"""Validate shuffle before sampling."""
if shuffle_before_sampling and not info.data['sample_size']:
raise ValueError('`shuffle_before_sampling` requires `sample_size` to be set.')
return shuffle_before_sampling

def _setup_sampling(self, duckdb_paths: list[str]) -> Schema:
assert self._con, 'setup() must be called first.'
if not self.shuffle_before_sampling and self.sample_size:
# Find each individual file.
glob_res: list[tuple[str]] = self._con.execute(
f'SELECT * FROM GLOB({duckdb_paths})').fetchall()
duckdb_files: list[str] = list(set([row[0] for row in glob_res]))
batch_size = max(1, min(self.sample_size // len(duckdb_files), ROWS_PER_BATCH_READ))
for duckdb_file in duckdb_files:
con = self._con.cursor()
duckdb_setup(con)
res = con.execute(f"""SELECT * FROM read_parquet('{duckdb_file}')""")
self._readers.append(res.fetch_record_batch(rows_per_batch=batch_size))
else:
sample_suffix = f'USING SAMPLE {self.sample_size}' if self.sample_size else ''
res = self._con.execute(f"""SELECT * FROM read_parquet({duckdb_paths}) {sample_suffix}""")
self._readers.append(res.fetch_record_batch(rows_per_batch=ROWS_PER_BATCH_READ))
return arrow_schema_to_schema(self._readers[0].schema)

@override
def setup(self) -> None:
filepaths = download_http_files(self.filepaths)
self._con = duckdb.connect(database=':memory:')
duckdb_setup(self._con)

# DuckDB expects s3 protocol: https://duckdb.org/docs/guides/import/s3_import.html.
s3_filepaths = [path.replace('gs://', 's3://') for path in filepaths]

# NOTE: We use duckdb here to increase parallelism for multiple files.
sample_suffix = f'USING SAMPLE {self.sample_size}' if self.sample_size else ''
self._con.execute(f"""
CREATE VIEW t as (SELECT * FROM read_parquet({s3_filepaths}) {sample_suffix});
""")
res = self._con.execute('SELECT COUNT(*) FROM t').fetchone()
duckdb_paths = [convert_path_to_duckdb(path) for path in filepaths]
res = self._con.execute(f'SELECT COUNT(*) FROM read_parquet({duckdb_paths})').fetchone()
num_items = cast(tuple[int], res)[0]
self._reader = self._con.execute('SELECT * from t').fetch_record_batch(
rows_per_batch=ROWS_PER_BATCH_READ)
# Create the source schema in prepare to share it between process and source_schema.
schema = arrow_schema_to_schema(self._reader.schema)
if self.sample_size:
self.sample_size = min(self.sample_size, num_items)
num_items = self.sample_size
schema = self._setup_sampling(duckdb_paths)
self._source_schema = SourceSchema(fields=schema.fields, num_items=num_items)

@override
@@ -81,10 +108,30 @@ def source_schema(self) -> SourceSchema:
@override
def process(self) -> Iterable[Item]:
"""Process the source."""
assert self._reader and self._con, 'setup() must be called first.'

for batch in self._reader:
yield from batch.to_pylist()

self._reader.close()
assert self._con, 'setup() must be called first.'

items_yielded = 0
done = False
while not done:
index = random.randint(0, len(self._readers) - 1)
reader = self._readers[index]
batch = None
try:
batch = reader.read_next_batch()
except StopIteration:
del self._readers[index]
if not self._readers:
done = True
break
continue
items = batch.to_pylist()
for item in items:
yield item
items_yielded += 1
if self.sample_size and items_yielded == self.sample_size:
done = True
break

for reader in self._readers:
reader.close()
self._con.close()
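
Taken together, a minimal sketch of how the new sampling path behaves when `ParquetSource` is
driven directly rather than through `ll.create_dataset` (the bucket path is illustrative):

```python
import lilac as ll

# Sample ~1000 rows without a full shuffle. With a glob over many shards, setup()
# opens one record-batch reader per shard, and process() pulls batches from a randomly
# chosen reader on each iteration so no single shard dominates the sample.
source = ll.ParquetSource(
  filepaths=['gs://my-bucket/data-*.parquet'],
  sample_size=1000,
  shuffle_before_sampling=False)
source.setup()
print(source.source_schema().num_items)  # min(sample_size, total row count)
rows = list(source.process())  # stops after sample_size items are yielded
```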