Skip to content

Commit

Permalink
Change it to --device
Browse files Browse the repository at this point in the history
  • Loading branch information
Ricardicus committed Apr 25, 2024
1 parent a013975 commit 9a30adb
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 4 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ jobs:
run: python prepro_tinyshakespeare.py

- name: Train model
run: DEVICE=cpu python train_gpt2.py
run: python train_gpt2.py --device cpu

- name: Compile training and testing program
run: make test_gpt2 train_gpt2
Expand Down
8 changes: 5 additions & 3 deletions train_gpt2.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,8 @@ def write_tokenizer(enc, filename):
parser.add_argument("--num_iterations", type=int, default=10, help="number of iterations to run")
parser.add_argument("--batch_size", type=int, default=4, help="batch size")
parser.add_argument("--sequence_length", type=int, default=64, help="sequence length")
parser.add_argument("--device", type=str, default=None, help="device to use (e.g., 'cpu', 'cuda:0')")

args = parser.parse_args()
B, T = args.batch_size, args.sequence_length
assert 1 <= T <= 1024
Expand All @@ -339,9 +341,9 @@ def write_tokenizer(enc, filename):
device = "cuda"
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
device = "mps"
# Override with environment variable if set
if 'DEVICE' in os.environ:
device = os.environ['DEVICE']
# Override with device argument if set
if args.device:
device = args.device
print(f"using device: {device}")

# create a context manager following the desired dtype and device
Expand Down

0 comments on commit 9a30adb

Please sign in to comment.