Skip to content

Commit

Permalink
traing_gpt2.py: Possibility to set device via environment variable
Browse files Browse the repository at this point in the history
  • Loading branch information
Ricardicus committed Apr 25, 2024
1 parent e86d63a commit ad5e912
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 1 deletion.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ jobs:
run: python prepro_tinyshakespeare.py

- name: Train model
run: python train_gpt2.py
run: DEVICE=cpu python train_gpt2.py

- name: Compile training and testing program
run: make test_gpt2 train_gpt2
Expand Down
3 changes: 3 additions & 0 deletions train_gpt2.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,9 @@ def write_tokenizer(enc, filename):
device = "cuda"
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
device = "mps"
# Override with environment variable if set
if 'DEVICE' in os.environ:
device = os.environ['DEVICE']
print(f"using device: {device}")

# create a context manager following the desired dtype and device
Expand Down

0 comments on commit ad5e912

Please sign in to comment.