Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Stateful dataloader #1241

Closed
wants to merge 30 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
e111150
initial commit
andrewkho Mar 28, 2024
ea7b7ac
clean up persistent_workers logic
andrewkho Mar 29, 2024
77a9868
tidy up
andrewkho Mar 29, 2024
d2135a9
simplify changes where it makes sense
andrewkho Mar 29, 2024
6477004
add statefulness tests in test_state_dict
andrewkho Mar 30, 2024
96eb0ed
add state_dict to dataloader
andrewkho Mar 30, 2024
eece67c
big update
andrewkho Apr 1, 2024
17d02c7
fix dataloader test
andrewkho Apr 1, 2024
ea42a6e
sampler, edge cases
andrewkho Apr 3, 2024
7f4d1fa
fix randomsampler state
andrewkho Apr 3, 2024
51fdd68
lint
andrewkho Apr 4, 2024
94e73f5
fix lint
andrewkho Apr 4, 2024
a4fe12e
lint
andrewkho Apr 4, 2024
884cf59
lint sampler
andrewkho Apr 4, 2024
73b5693
add psutil requirement
andrewkho Apr 4, 2024
fdbb5be
add iterable dataset fast-forward
andrewkho Apr 4, 2024
458fc21
mypy
andrewkho Apr 4, 2024
4d194b8
mypy
andrewkho Apr 4, 2024
fac6956
remove expected failure from dill test
andrewkho Apr 4, 2024
38b3a44
catch exceptions in state generation
andrewkho Apr 4, 2024
907f4e9
more mypy
andrewkho Apr 4, 2024
f13a123
add CI for StatefulDataLoader; fix generator pickling for windows/mac
andrewkho Apr 5, 2024
801f400
return state
andrewkho Apr 5, 2024
33f12c6
Fix sampler iter state restore
gokulavasan Apr 6, 2024
761a7cf
Fix random sampler patch state management. Also add unit tests for it
gokulavasan Apr 9, 2024
0c862b1
Capture base_seed as part of main snapshot and initialize worker loop…
gokulavasan Apr 13, 2024
c2cef69
change iterable dataset to use worker torch rng
gokulavasan Apr 13, 2024
2c3b338
Create constants for state dict key names
gokulavasan Apr 13, 2024
82a7872
Add unit test to assert error if worker count changes on resumption
gokulavasan Apr 13, 2024
a2101ca
Copyright and minor linter
gokulavasan Apr 13, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -89,4 +89,5 @@ jobs:
pytest --no-header -v test --ignore=test/test_period.py --ignore=test/test_text_examples.py
--ignore=test/test_audio_examples.py --ignore=test/test_aistore.py
--ignore=test/dataloader2/test_dataloader2.py --ignore=test/dataloader2/test_mprs.py
--ignore=test/test_distributed.py
--ignore=test/test_distributed.py --ignore=test/stateful_dataloader/test_dataloader.py
--ignore=test/stateful_dataloader/test_state_dict.py
88 changes: 88 additions & 0 deletions .github/workflows/stateful_dataloader_ci.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
name: Run StatefulDataLoader Tests
on:
push:
branches:
- main
- release/*
tags:
pull_request:
types: [opened, synchronize, reopened, labeled]
branches:
- main
# For PR created by ghstack
- gh/*/*/base
- release/*

jobs:
test:
if:
${{ github.repository_owner == 'pytorch' && (github.event.action != 'labeled' ||
startsWith(github.event.label.name, 'ciflow')) }}
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
os:
- macos-latest
- ubuntu-latest
- windows-latest
python-version:
- 3.8
- 3.9
- "3.10"
steps:
- name: Get PyTorch Channel
shell: bash
run: |
if [[ "${{ github.base_ref }}" == release/* ]] || [[ "${{ github.ref }}" == refs/heads/release/* ]] || [[ "${{ github.ref }}" == refs/tags/v* ]]; then
PT_CHANNEL="https://download.pytorch.org/whl/test/cpu/torch_test.html"
else
PT_CHANNEL="https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html"
fi
echo "value=$PT_CHANNEL" >> $GITHUB_OUTPUT
id: pytorch_channel
- name: Setup additional system libraries
if: startsWith( matrix.os, 'ubuntu' )
run: |
sudo add-apt-repository multiverse
sudo apt update
sudo apt install rar unrar libssl-dev libcurl4-openssl-dev zlib1g-dev
- name: Setup Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Setup msbuild on Windows
if: matrix.os == 'windows-latest'
uses: microsoft/setup-msbuild@v1.1
- name: Set up Visual Studio shell
if: matrix.os == 'windows-latest'
uses: egor-tensin/vs-shell@v2
with:
arch: x64
- name: Check out source repository
uses: actions/checkout@v3
with:
submodules: recursive
- name: Install dependencies
run: |
pip3 install -r requirements.txt
pip3 install networkx
pip3 install --pre torch -f "${{ steps.pytorch_channel.outputs.value }}"
pip3 install cmake ninja
echo "/home/runner/.local/bin" >> $GITHUB_PATH
- name: Build TorchData
run: |
pip3 install .
env:
BUILD_S3: 1
- name: Install test requirements
run: pip3 install -r test/requirements.txt
- name: Test documentation examples
run: |
cd ./docs
pip3 install -r requirements.txt
make doctest
cd ..
- name: Run StatefulDataLoader tests with pytest
if: ${{ ! contains(github.event.pull_request.labels.*.name, 'ciflow/slow') }}
run: pytest --no-header -v test/stateful_dataloader
1 change: 1 addition & 0 deletions test/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,4 @@ datasets
graphviz
adlfs
awscli>=1.27.66
psutil
Loading
Loading