From a01397592a264ff491d92e27e2abcfceac42eda1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rickard=20Hallerb=C3=A4ck?= Date: Thu, 25 Apr 2024 08:27:11 +0200 Subject: [PATCH] traing_gpt2.py: Possibility to set device via environment variable --- .github/workflows/ci.yml | 2 +- train_gpt2.py | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 97def0705..f44059f57 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -35,7 +35,7 @@ jobs: run: python prepro_tinyshakespeare.py - name: Train model - run: python train_gpt2.py + run: DEVICE=cpu python train_gpt2.py - name: Compile training and testing program run: make test_gpt2 train_gpt2 diff --git a/train_gpt2.py b/train_gpt2.py index d52b25c3d..2a93845de 100644 --- a/train_gpt2.py +++ b/train_gpt2.py @@ -339,6 +339,9 @@ def write_tokenizer(enc, filename): device = "cuda" elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): device = "mps" + # Override with environment variable if set + if 'DEVICE' in os.environ: + device = os.environ['DEVICE'] print(f"using device: {device}") # create a context manager following the desired dtype and device