Custom dataloader registry support #2932

Open
wants to merge 96 commits into
base: main
96 commits
7088e4b
copying CZI custom dataloader into our repo
ori-kron-wis Jul 28, 2024
cc72b05
added some fixes to the custom dataloader stuff
ori-kron-wis Jul 30, 2024
46048e3
Some suggestions
canergen Jul 30, 2024
14f343d
Changes to datamodule pipeline
canergen Jul 31, 2024
17282cd
Fixed attr_dict
canergen Jul 31, 2024
a4143f5
added some fixes based on custom data loader test
ori-kron-wis Aug 1, 2024
69abc47
Changes to dataloader
canergen Aug 6, 2024
dc21a3d
copying CZI custom dataloader into our repo
ori-kron-wis Jul 28, 2024
a1098b3
added some fixes to the custom dataloader stuff
ori-kron-wis Jul 30, 2024
b07216b
Some suggestions
canergen Jul 30, 2024
a578af1
Changes to datamodule pipeline
canergen Jul 31, 2024
42434ec
Fixed attr_dict
canergen Jul 31, 2024
3d0c890
added some fixes based on custom data loader test
ori-kron-wis Aug 1, 2024
eff5b1e
Changes to dataloader
canergen Aug 6, 2024
cbdc26e
Merge remote-tracking branch 'origin/ori-2907-custom-dataloader-regis…
ori-kron-wis Aug 7, 2024
18d65a6
add changes to tests and some merging with main following custom data…
ori-kron-wis Aug 7, 2024
4fe3ee1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 7, 2024
1110966
just put the cutom dataloder2 test under remarks so hook tests will r…
ori-kron-wis Aug 7, 2024
7972bdc
fixes
ori-kron-wis Aug 7, 2024
2d86c43
additional external models fixes once there is a registry
ori-kron-wis Aug 7, 2024
3c44d86
fixed a few failed tests
ori-kron-wis Aug 11, 2024
c0889d8
fix archesmixin init and added new custom dataloader test and github …
ori-kron-wis Aug 11, 2024
8fe043c
fix again for from __future__ import annotations
ori-kron-wis Aug 11, 2024
d8cf0f6
fix for run custom dataloader in github action
ori-kron-wis Aug 11, 2024
c41e8b2
rollback
ori-kron-wis Aug 11, 2024
6ec5d4d
added label to the new githubaction for custom dataloader
ori-kron-wis Aug 11, 2024
6bce317
fix for github action for custom dataloaders
ori-kron-wis Aug 12, 2024
1f4ae9d
another fix to custom dataloder test and github action
ori-kron-wis Aug 12, 2024
de1f30b
another fix to custom dataloder test and github action
ori-kron-wis Aug 12, 2024
49fa01e
another fix to custom dataloder test and github action
ori-kron-wis Aug 12, 2024
e33a935
another fix to custom dataloder test and github action
ori-kron-wis Aug 12, 2024
48627d9
another fix to custom dataloder test and github action
ori-kron-wis Aug 12, 2024
609094d
another fix to custom dataloder test and github action
ori-kron-wis Aug 12, 2024
8cf3517
another fix to custom dataloder test and github action
ori-kron-wis Aug 12, 2024
ba5a028
another fix to custom dataloder test and github action
ori-kron-wis Aug 12, 2024
a7dc3fe
another fix to custom dataloder test and github action
ori-kron-wis Aug 12, 2024
f3ff0f8
another fix to custom dataloder test and github action
ori-kron-wis Aug 12, 2024
083c76e
Merge branch 'main' into ori-2907-custom-dataloader-registry
ori-kron-wis Sep 9, 2024
70bba69
Merge branch 'main' into ori-2907-custom-dataloader-registry
ori-kron-wis Sep 15, 2024
8c75662
Merge branch 'main' into ori-2907-custom-dataloader-registry
ori-kron-wis Sep 16, 2024
b6eb2f1
Returned REGISTRY_KEYS for import, after was drop in recent merges
ori-kron-wis Sep 16, 2024
2979ea2
It is ok to drop it after scarches categorial covariates fix
ori-kron-wis Sep 16, 2024
67e9b34
Merge branch 'main' into ori-2907-custom-dataloader-registry
ori-kron-wis Sep 17, 2024
11fe33a
Merge branch 'main' into ori-2907-custom-dataloader-registry
ori-kron-wis Sep 17, 2024
4a648ff
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 17, 2024
e3831cb
moved to type checking blocks beucase of ruff updates
ori-kron-wis Sep 17, 2024
e1837bd
Merge branch 'main' into ori-2907-custom-dataloader-registry
ori-kron-wis Sep 26, 2024
bf4d3bf
Merge remote-tracking branch 'origin/main' into ori-2907-custom-datal…
ori-kron-wis Oct 7, 2024
2cc8ff9
updated for CZI custom dataloader test and backend
ori-kron-wis Oct 9, 2024
e62dc3a
Merge branch 'main' into ori-2907-custom-dataloader-registry
ori-kron-wis Oct 9, 2024
41fd877
added cellxgene-census folder as well for debug (will not be merged)
ori-kron-wis Oct 9, 2024
10ada9c
added cellxgene-census packge to run test
ori-kron-wis Oct 9, 2024
dd3649c
added torchdata packge to run test
ori-kron-wis Oct 9, 2024
c6acb5a
fixed the test workwflow
ori-kron-wis Oct 9, 2024
b35c6eb
adding the lamindb as well
ori-kron-wis Oct 10, 2024
1801604
fix the c.dataloders test
ori-kron-wis Oct 10, 2024
ed77a65
fix the c.dataloders test
ori-kron-wis Oct 10, 2024
fc831d5
fix the c.dataloders test
ori-kron-wis Oct 10, 2024
7400621
fix the c.dataloders test
ori-kron-wis Oct 10, 2024
47376ca
fix the c.dataloders test
ori-kron-wis Oct 10, 2024
f94f7fa
removed redundat functions in code base
ori-kron-wis Oct 13, 2024
962f043
Added scanvi support, including CZI datamodule fix for it
ori-kron-wis Oct 15, 2024
5c21d71
Merge remote-tracking branch 'origin/main' into ori-2907-custom-datal…
ori-kron-wis Oct 20, 2024
a8aeffe
updates from main
ori-kron-wis Dec 25, 2024
1283616
more updates from main
ori-kron-wis Dec 25, 2024
624ee72
Merge branch 'main' into ori-2907-custom-dataloader-registry
ori-kron-wis Dec 25, 2024
6d4f368
Merge remote-tracking branch 'origin/ori-2907-custom-dataloader-regis…
ori-kron-wis Dec 25, 2024
8ab01a4
updated related to tests
ori-kron-wis Dec 25, 2024
31e1d44
updated related to tests
ori-kron-wis Dec 25, 2024
93666fa
Running DataLoader MappedCollection
canergen Dec 30, 2024
1d1d6d3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 30, 2024
7695a8a
Fixed LaminDB dataloader
canergen Dec 31, 2024
e4d732a
Merge branch 'ori-2907-custom-dataloader-registry' of https://github.…
canergen Dec 31, 2024
a651442
LaminDB dataloader test.
canergen Dec 31, 2024
9767b8c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 31, 2024
719e740
Merge branch 'main' into ori-2907-custom-dataloader-registry
ori-kron-wis Dec 31, 2024
1a4c796
Merge remote-tracking branch 'origin/main' into ori-2907-custom-datal…
ori-kron-wis Jan 8, 2025
5666558
Changes for MappedCollection.
canergen Jan 8, 2025
c740dd2
Merge branch 'ori-2907-custom-dataloader-registry' of https://github.…
canergen Jan 8, 2025
61f2e27
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 8, 2025
874935b
Add other notebook for testing new dataloader
canergen Jan 9, 2025
f2c63bd
Merge branch 'ori-2907-custom-dataloader-registry' of https://github.…
canergen Jan 9, 2025
35d45c8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 9, 2025
38c670f
Merge remote-tracking branch 'origin/main' into ori-2907-custom-datal…
ori-kron-wis Jan 16, 2025
c93fc97
updates to test script
ori-kron-wis Jan 16, 2025
5045fc3
remove old test nb
ori-kron-wis Jan 16, 2025
55775f9
update test
ori-kron-wis Jan 16, 2025
7ccdf8d
update test
ori-kron-wis Jan 16, 2025
f88dc50
updated czi cdl
ori-kron-wis Jan 16, 2025
1f3ea11
updated czi cdl
ori-kron-wis Jan 16, 2025
d0ec46f
Merge remote-tracking branch 'origin/main' into ori-2907-custom-datal…
ori-kron-wis Jan 20, 2025
e304922
merge with main + updates
ori-kron-wis Feb 9, 2025
5ccd1ed
more updates
ori-kron-wis Feb 9, 2025
96a09d8
more updates
ori-kron-wis Feb 9, 2025
601d86f
more updates
ori-kron-wis Feb 9, 2025
2485bb6
pyproject update
ori-kron-wis Feb 10, 2025
2 changes: 1 addition & 1 deletion .github/workflows/test_linux.yml
@@ -53,7 +53,7 @@ jobs:
           DISPLAY: :42
           COLUMNS: 120
         run: |
-          coverage run -m pytest -v --color=yes
+          coverage run -m pytest -v --color=yes -m "not custom_dataloader"
           coverage report

       - uses: codecov/codecov-action@v4
89 changes: 89 additions & 0 deletions .github/workflows/test_linux_custom_dataloader.yml
@@ -0,0 +1,89 @@
name: test (custom dataloaders)

on:
  push:
    branches: [main, "[0-9]+.[0-9]+.x"]
  pull_request:
    branches: [main, "[0-9]+.[0-9]+.x"]
    types: [labeled, synchronize, opened]
  schedule:
    - cron: "0 10 * * *" # runs at 10:00 UTC (02:00 PST / 03:00 PDT) every day
  workflow_dispatch:

concurrency:
  group: ${{ github.workflow }}-${{ github.ref }}
  cancel-in-progress: true

jobs:
  test:
    # if PR has label "custom_dataloader" or "all tests" or if scheduled or manually triggered
    if: >-
      (
        contains(github.event.pull_request.labels.*.name, 'custom_dataloader') ||
        contains(github.event.pull_request.labels.*.name, 'all tests') ||
        contains(github.event_name, 'schedule') ||
        contains(github.event_name, 'workflow_dispatch')
      )

    runs-on: ${{ matrix.os }}

    defaults:
      run:
        shell: bash -e {0} # -e to fail on error

    strategy:
      fail-fast: false
      matrix:
        os: [ubuntu-latest]
        python: ["3.11"]

    name: integration

    env:
      OS: ${{ matrix.os }}
      PYTHON: ${{ matrix.python }}

    steps:
      - uses: actions/checkout@v4

      - uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python }}
          cache: "pip"
          cache-dependency-path: "**/pyproject.toml"

      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip wheel uv
          python -m uv pip install --system "scvi-tools[tests] @ ."
          python -m pip install scdataloader
          python -m pip install cellxgene-census
          python -m pip install tiledbsoma
          python -m pip install s3fs
          python -m pip install torchdata==0.9.0
          python -m pip install psutil
          python -m pip install lamindb
          python -m pip install bionty==0.51.0
          python -m pip install biomart

      - name: Install Specific Branch of Repository
        env:
          GH_TOKEN: ${{ secrets.GH_TOKEN }}
        run: |
          git config --global url."https://${GH_TOKEN}:x-oauth-basic@github.com/".insteadOf "https://github.com/"
          git clone --single-branch --branch ebezzi/census-scvi-datamodule https://github.com/ori-kron-wis/cellxgene-census.git
          git clone --single-branch --branch main https://github.com/jkobject/scDataLoader.git

      - name: Run specific custom dataloader pytest
        env:
          MPLBACKEND: agg
          PLATFORM: ${{ matrix.os }}
          DISPLAY: :42
          COLUMNS: 120
        run: |
          coverage run -m pytest tests/dataloaders/test_custom_dataloader.py -v --color=yes --custom-dataloader-tests
          coverage report

      - uses: codecov/codecov-action@v4
        with:
          token: ${{ secrets.CODECOV_TOKEN }}
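
The two workflows rely on a `custom_dataloader` marker (deselected with `-m "not custom_dataloader"` in the main job) and on a `--custom-dataloader-tests` flag (passed in the dedicated job). The diff shown here does not include the pytest wiring for either, so the following is only a minimal sketch of how a `conftest.py` could provide it. The helper `should_skip` and the help and skip messages are hypothetical and not taken from this PR; the hook names are the standard pytest ones.

```python
# conftest.py sketch (an assumption for illustration, not code from this PR)

MARKER = "custom_dataloader"        # marker name used with `-m "not custom_dataloader"`
FLAG = "--custom-dataloader-tests"  # opt-in flag passed by the dedicated workflow


def should_skip(flag_enabled: bool, keywords) -> bool:
    """Hypothetical helper: skip marked tests unless the opt-in flag was given."""
    return (not flag_enabled) and MARKER in keywords


def pytest_addoption(parser):
    """Register the opt-in flag so pytest accepts it on the command line."""
    parser.addoption(
        FLAG,
        action="store_true",
        default=False,
        help="Run tests marked as custom_dataloader.",
    )


def pytest_configure(config):
    """Declare the marker so `-m` expressions do not warn about unknown markers."""
    config.addinivalue_line("markers", f"{MARKER}: tests that need custom dataloader backends")


def pytest_collection_modifyitems(config, items):
    """Attach a skip marker to every marked test when the flag is absent."""
    import pytest  # imported lazily so the module itself has no import-time dependency

    flag_enabled = config.getoption(FLAG)
    skip = pytest.mark.skip(reason=f"need {FLAG} option to run")
    for item in items:
        if should_skip(flag_enabled, item.keywords):
            item.add_marker(skip)
```

With wiring like this, a plain `pytest` run skips the marked tests, while the dedicated job opts in explicitly through the flag.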
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -14,6 +14,7 @@ to [Semantic Versioning]. Full commit history is available in the
   representation learning in single-cell RNA sequencing data {pr}`3015`, {pr}`3091`.
 - Add {class}`scvi.external.RESOLVI` for bias correction in single-cell resolved spatial
   transcriptomics {pr}`3144`.
+- Add support for using Lamin custom dataloaders with {class}`scvi.model.SCVI` {pr}`2932`.

 #### Fixed

1 change: 1 addition & 0 deletions cellxgene-census
Submodule cellxgene-census added at fac658
2 changes: 1 addition & 1 deletion docs/tutorials/notebooks
4 changes: 3 additions & 1 deletion pyproject.toml
@@ -95,9 +95,11 @@ regseq = ["biopython>=1.81", "genomepy"]
 scanpy = ["scanpy>=1.10", "scikit-misc"]
 # for convinient files sharing
 pooch = ["pooch"]
+# for custom dataloaders
+dataloaders = ["lamindb", "biomart", "bionty", "cellxgene_lamin"]

 optional = [
-    "scvi-tools[autotune,aws,hub,pooch,regseq,scanpy]"
+    "scvi-tools[autotune,aws,hub,pooch,regseq,scanpy,dataloaders]"
 ]
 tutorials = [
     "cell2location",
9 changes: 9 additions & 0 deletions src/scvi/data/_utils.py
@@ -16,6 +16,7 @@
 from torch import as_tensor, sparse_csc_tensor, sparse_csr_tensor

 from scvi import REGISTRY_KEYS, settings
+from scvi.utils import attrdict

 from . import _constants

@@ -150,6 +151,14 @@ def _set_data_in_registry(
         setattr(adata, attr_name, attribute)


+def _get_summary_stats_from_registry(registry: dict) -> attrdict:
+    summary_stats = {}
+    for field_registry in registry[_constants._FIELD_REGISTRIES_KEY].values():
+        field_summary_stats = field_registry[_constants._SUMMARY_STATS_KEY]
+        summary_stats.update(field_summary_stats)
+    return attrdict(summary_stats)
+
+
 def _verify_and_correct_data_format(adata: AnnData, attr_name: str, attr_key: str | None):
     """Check data format and correct if necessary.

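
To illustrate what `_get_summary_stats_from_registry` does, here is a small self-contained sketch that flattens the per-field `summary_stats` of a registry dict into one attribute-accessible mapping. It assumes the two constants resolve to the literal keys `"field_registries"` and `"summary_stats"` (the keys used by the `registry` property elsewhere in this PR), and it uses a hypothetical `AttrDict` stand-in rather than the real `scvi.utils.attrdict`. The numbers in the example registry are made up.

```python
# Sketch only: AttrDict and the key constants below are stand-ins, not scvi-tools code.

FIELD_REGISTRIES_KEY = "field_registries"  # assumed value of _constants._FIELD_REGISTRIES_KEY
SUMMARY_STATS_KEY = "summary_stats"        # assumed value of _constants._SUMMARY_STATS_KEY


class AttrDict(dict):
    """Hypothetical stand-in: a dict whose keys can also be read as attributes."""

    def __getattr__(self, name):
        try:
            return self[name]
        except KeyError as err:
            raise AttributeError(name) from err


def get_summary_stats_from_registry(registry: dict) -> AttrDict:
    """Merge the summary stats of every field registry into one flat mapping."""
    summary_stats = {}
    for field_registry in registry[FIELD_REGISTRIES_KEY].values():
        summary_stats.update(field_registry[SUMMARY_STATS_KEY])
    return AttrDict(summary_stats)


# Made-up registry with the same nesting as the `registry` property of the datamodule.
registry = {
    "field_registries": {
        "X": {"summary_stats": {"n_vars": 2000, "n_cells": 500}},
        "batch": {"summary_stats": {"n_batch": 3}},
        "labels": {"summary_stats": {"n_labels": 1}},
        "size_factor": {"summary_stats": {}},
    }
}

stats = get_summary_stats_from_registry(registry)
print(stats.n_vars, stats.n_cells, stats.n_batch, stats.n_labels)  # prints: 2000 500 3 1
```

Because the merge uses `dict.update`, two fields reporting the same stat name would overwrite each other, with the later field winning.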
2 changes: 2 additions & 0 deletions src/scvi/dataloaders/__init__.py
@@ -3,6 +3,7 @@

 from ._ann_dataloader import AnnDataLoader
 from ._concat_dataloader import ConcatDataLoader
+from ._custom_dataloders import MappedCollectionDataModule
 from ._data_splitting import (
     DataSplitter,
     DeviceBackedDataSplitter,
@@ -20,4 +21,5 @@
     "DataSplitter",
     "SemiSupervisedDataSplitter",
     "BatchDistributedSampler",
+    "MappedCollectionDataModule",
 ]
194 changes: 194 additions & 0 deletions src/scvi/dataloaders/_custom_dataloders.py
@@ -0,0 +1,194 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import psutil
from lightning.pytorch import LightningDataModule
from torch.utils.data import DataLoader

import scvi

if TYPE_CHECKING:
    import lamindb as ln
    import numpy as np
    import pandas as pd


class MappedCollectionDataModule(LightningDataModule):
    def __init__(
        self,
        collection: ln.Collection,
        batch_key: str | None = None,
        label_key: str | None = None,
        batch_size: int = 128,
        **kwargs,
    ):
        super().__init__()
        self._batch_size = batch_size
        self._batch_key = batch_key
        self._label_key = label_key
        self._parallel = kwargs.pop("parallel", True)
        # here we initialize MappedCollection to use in a pytorch DataLoader
        self._dataset = collection.mapped(
            obs_keys=self._batch_key, parallel=self._parallel, **kwargs
        )
        # needed by scvi and lightning.pytorch
        self._log_hyperparams = False
        self.allow_zero_length_dataloader_with_multiple_devices = False

    def close(self):
        self._dataset.close()

    def setup(self, stage):
        pass

    def train_dataloader(self):
        return self._create_dataloader(shuffle=True)

    def inference_dataloader(self):
        """Dataloader for inference with `on_before_batch_transfer` applied."""
        dataloader = self._create_dataloader(shuffle=False, batch_size=4096)
        return self._InferenceDataloader(dataloader, self.on_before_batch_transfer)

    def _create_dataloader(self, shuffle, batch_size=None):
        if self._parallel:
            # leave one CPU free for the main process
            num_workers = psutil.cpu_count() - 1
            worker_init_fn = self._dataset.torch_worker_init_fn
        else:
            num_workers = 0
            worker_init_fn = None
        if batch_size is None:
            batch_size = self._batch_size
        return DataLoader(
            self._dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            worker_init_fn=worker_init_fn,
        )

    @property
    def n_obs(self) -> int:
        return self._dataset.n_obs

    @property
    def var_names(self) -> pd.Index:
        return self._dataset.var_joint

    @property
    def n_vars(self) -> int:
        return self._dataset.n_vars

    @property
    def n_batch(self) -> int:
        if self._batch_key is None:
            return 1
        return len(self._dataset.encoders[self._batch_key])

    @property
    def n_labels(self) -> int:
        if self._label_key is None:
            return 1
        return len(self._dataset.encoders[self._label_key])

    @property
    def labels(self) -> np.ndarray:
        return self._dataset[self._label_key]

    @property
    def registry(self) -> dict:
        return {
            "scvi_version": scvi.__version__,
            "model_name": "SCVI",
            "setup_args": {
                "layer": None,
                "batch_key": self._batch_key,
                "labels_key": self._label_key,
                "size_factor_key": None,
                "categorical_covariate_keys": None,
                "continuous_covariate_keys": None,
            },
            "field_registries": {
                "X": {
                    "data_registry": {"attr_name": "X", "attr_key": None},
                    "state_registry": {
                        "n_obs": self.n_obs,
                        "n_vars": self.n_vars,
                        "column_names": self.var_names,
                    },
                    "summary_stats": {"n_vars": self.n_vars, "n_cells": self.n_obs},
                },
                "batch": {
                    "data_registry": {"attr_name": "obs", "attr_key": "_scvi_batch"},
                    "state_registry": {
                        "categorical_mapping": self.batch_keys,
                        "original_key": self._batch_key,
                    },
                    "summary_stats": {"n_batch": self.n_batch},
                },
                "labels": {
                    "data_registry": {"attr_name": "obs", "attr_key": "_scvi_labels"},
                    "state_registry": {
                        "categorical_mapping": self.label_keys,
                        "original_key": self._label_key,
                        "unlabeled_category": "unlabeled",
                    },
                    "summary_stats": {"n_labels": self.n_labels},
                },
                "size_factor": {
                    "data_registry": {},
                    "state_registry": {},
                    "summary_stats": {},
                },
                "extra_categorical_covs": {
                    "data_registry": {},
                    "state_registry": {},
                    "summary_stats": {"n_extra_categorical_covs": 0},
                },
                "extra_continuous_covs": {
                    "data_registry": {},
                    "state_registry": {},
                    "summary_stats": {"n_extra_continuous_covs": 0},
                },
            },
            "setup_method_name": "setup_anndata",
        }

    @property
    def batch_keys(self) -> dict | None:
        if self._batch_key is None:
            return None
        return self._dataset.encoders[self._batch_key]

    @property
    def label_keys(self) -> dict | None:
        if self._label_key is None:
            return None
        return self._dataset.encoders[self._label_key]

    def on_before_batch_transfer(
        self,
        batch,
        dataloader_idx,
    ):
        X_KEY: str = "X"
        BATCH_KEY: str = "batch"
        LABEL_KEY: str = "labels"

        return {
            X_KEY: batch["X"].float(),
            BATCH_KEY: batch[self._batch_key][:, None] if self._batch_key is not None else None,
            LABEL_KEY: 0,
        }

    class _InferenceDataloader:
        """Wrapper to apply `on_before_batch_transfer` during iteration."""

        def __init__(self, dataloader, transform_fn):
            self.dataloader = dataloader
            self.transform_fn = transform_fn

        def __iter__(self):
            for batch in self.dataloader:
                yield self.transform_fn(batch, dataloader_idx=None)

        def __len__(self):
            return len(self.dataloader)
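
Outside a Lightning `Trainer`, `on_before_batch_transfer` is not called automatically, which is presumably why `inference_dataloader` wraps the loader in `_InferenceDataloader`. The pattern can be sketched without torch or lamindb. Everything below is illustrative: `TransformedLoader`, `to_model_batch`, and the `"donor"` key are hypothetical names, and plain lists stand in for tensors.

```python
# Dependency-free sketch of the wrapper pattern; names and data are made up.
from functools import partial
from typing import Any, Callable, Iterator, Optional


class TransformedLoader:
    """Apply `transform_fn` to every batch yielded by an underlying sized iterable."""

    def __init__(self, loader, transform_fn: Callable[[Any], Any]):
        self.loader = loader
        self.transform_fn = transform_fn

    def __iter__(self) -> Iterator[Any]:
        for batch in self.loader:
            yield self.transform_fn(batch)

    def __len__(self) -> int:
        return len(self.loader)


def to_model_batch(batch: dict, batch_key: Optional[str] = None) -> dict:
    """Mimic the key mapping above: float counts, column-shaped batch codes, dummy labels."""
    return {
        "X": [[float(v) for v in row] for row in batch["X"]],
        "batch": [[code] for code in batch[batch_key]] if batch_key is not None else None,
        "labels": 0,
    }


# Two raw "batches" as a mapped dataset might yield them, keyed by the obs column name.
raw_batches = [
    {"X": [[1, 0], [0, 2]], "donor": [0, 1]},
    {"X": [[3, 1]], "donor": [1]},
]

loader = TransformedLoader(raw_batches, partial(to_model_batch, batch_key="donor"))
for model_batch in loader:
    print(model_batch["X"], model_batch["batch"], model_batch["labels"])
```

Keeping `__len__` on the wrapper lets callers that need the number of batches, such as progress bars, keep working on the wrapped loader.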
3 changes: 2 additions & 1 deletion src/scvi/dataloaders/_data_splitting.py
@@ -386,7 +386,8 @@ class is :class:`~scvi.dataloaders.SemiSupervisedDataLoader`,

     def __init__(
         self,
-        adata_manager: AnnDataManager,
+        adata_manager: AnnDataManager | None = None,
+        datamodule: pl.LightningDataModule | None = None,
         train_size: float | None = None,
         validation_size: float | None = None,
         shuffle_set_split: bool = True,
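
With both `adata_manager` and `datamodule` now defaulting to `None`, the constructor can be called with neither or with both. The hunk above does not show how the splitter handles those cases, so the following is only a sketch of one possible guard. It assumes that exactly one source should be given, which is not something this PR states, and `resolve_data_source` is a hypothetical helper.

```python
# Hypothetical guard; not part of the PR, and the "exactly one" rule is an assumption.
from typing import Any, Optional


def resolve_data_source(
    adata_manager: Optional[Any] = None,
    datamodule: Optional[Any] = None,
) -> str:
    """Return the name of the data source to use, or raise if the choice is ambiguous."""
    # True when both are None or both are set, i.e. not exactly one source was given.
    if (adata_manager is None) == (datamodule is None):
        raise ValueError("Provide exactly one of `adata_manager` or `datamodule`.")
    return "adata_manager" if adata_manager is not None else "datamodule"


print(resolve_data_source(datamodule=object()))  # prints: datamodule
```

Failing early in `__init__` gives a clearer error than an `AttributeError` on `None` later in `setup`.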
3 changes: 2 additions & 1 deletion src/scvi/external/resolvi/_model.py
@@ -98,7 +98,8 @@ class RESOLVI(

     def __init__(
         self,
-        adata: AnnData,
+        adata: AnnData | None,
+        registry: dict | None = None,
         n_hidden: int = 32,
         n_hidden_encoder: int = 128,
         n_latent: int = 10,
3 changes: 2 additions & 1 deletion src/scvi/external/stereoscope/_model.py
@@ -53,7 +53,8 @@ class RNAStereoscope(UnsupervisedTrainingMixin, BaseModelClass):

     def __init__(
         self,
-        sc_adata: AnnData,
+        sc_adata: AnnData | None = None,
+        registry: dict | None = None,
         **model_kwargs,
     ):
         super().__init__(sc_adata)
1 change: 1 addition & 0 deletions src/scvi/external/stereoscope/_module.py
@@ -140,6 +140,7 @@ def __init__(
         n_spots: int,
         sc_params: tuple[np.ndarray],
         prior_weight: Literal["n_obs", "minibatch"] = "n_obs",
+        **model_kwargs,
     ):
         super().__init__()
         # unpack and copy parameters