From aa7fa8b8fed46e68859662518212ef559105ea55 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Rickard=20Hallerb=C3=A4ck?=
Date: Thu, 25 Apr 2024 08:27:11 +0200
Subject: [PATCH] train_gpt2.py: Possibility to set device via command-line
 argument

---
 .github/workflows/ci.yml | 2 +-
 train_gpt2.py            | 5 +++++
 2 files changed, 6 insertions(+), 1 deletion(-)

diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 97def0705..34088fafc 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -35,7 +35,7 @@ jobs:
         run: python prepro_tinyshakespeare.py
 
       - name: Train model
-        run: python train_gpt2.py
+        run: python train_gpt2.py --device cpu
 
       - name: Compile training and testing program
         run: make test_gpt2 train_gpt2
diff --git a/train_gpt2.py b/train_gpt2.py
index d52b25c3d..d1d213bec 100644
--- a/train_gpt2.py
+++ b/train_gpt2.py
@@ -328,6 +328,8 @@ def write_tokenizer(enc, filename):
     parser.add_argument("--num_iterations", type=int, default=10, help="number of iterations to run")
     parser.add_argument("--batch_size", type=int, default=4, help="batch size")
     parser.add_argument("--sequence_length", type=int, default=64, help="sequence length")
+    parser.add_argument("--device", type=str, default=None, help="device to use (e.g., 'cpu', 'cuda:0')")
+
     args = parser.parse_args()
     B, T = args.batch_size, args.sequence_length
     assert 1 <= T <= 1024
@@ -339,6 +341,9 @@ def write_tokenizer(enc, filename):
         device = "cuda"
     elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
         device = "mps"
+    # Override with device argument if set
+    if args.device:
+        device = args.device
     print(f"using device: {device}")
 
     # create a context manager following the desired dtype and device
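
Note: the new --device flag overrides the autodetected device, which is what lets the CI workflow force CPU with `python train_gpt2.py --device cpu`. Below is a minimal standalone sketch of the selection behaviour the patch produces; the helper name select_device is illustrative only and not part of the patch, which keeps the override inline in train_gpt2.py.

    # Sketch (assumed standalone, not part of the patch) of the device selection
    # train_gpt2.py ends up with: autodetect cuda/mps/cpu, then let an explicit
    # --device argument override the autodetected choice.
    import argparse
    import torch

    def select_device(requested=None):
        # select_device is an illustrative helper; the patch keeps this logic inline
        device = "cpu"
        if torch.cuda.is_available():
            device = "cuda"
        elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
            device = "mps"
        if requested:
            device = requested  # an explicit request wins over autodetection
        return device

    if __name__ == "__main__":
        parser = argparse.ArgumentParser()
        parser.add_argument("--device", type=str, default=None,
                            help="device to use (e.g., 'cpu', 'cuda:0')")
        args = parser.parse_args()
        print(f"using device: {select_device(args.device)}")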