Commit be00d62

Add an easy way to run a script for a few steps only

I've wanted this tool for a while, so I figured I should just propose it. Often I need to test out a script or Colab I did not write, and I just want to run a few training steps without hunting down every call to `fit()` in the script. This adds a debugging tool to do just that:

```
KERAS_MAX_EPOCHS=1 KERAS_MAX_STEPS_PER_EPOCH=5 python train.py
```
1 parent 81c5097 commit be00d62
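The same limits can also be set programmatically through the new `keras.config` functions added in this commit. Below is a minimal sketch of that usage, assuming a Keras build that includes this change; the model and data are placeholders, not part of the commit:

```python
import numpy as np
import keras

# Cap any subsequent fit() call for quick debugging runs.
keras.config.set_max_epochs(1)
keras.config.set_max_steps_per_epoch(5)

# Placeholder model and data, just to exercise the cap.
model = keras.Sequential([keras.layers.Dense(1)])
model.compile(optimizer="sgd", loss="mse")
x, y = np.ones((64, 4)), np.ones((64, 1))

# Although epochs=10 is requested, training should stop after 1 epoch
# of at most 5 steps, with a warning about the applied limit.
model.fit(x, y, epochs=10, batch_size=8)

# Lift the limits again.
keras.config.set_max_epochs(None)
keras.config.set_max_steps_per_epoch(None)
```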

8 files changed, +154 -1 lines changed

keras/api/_tf_keras/keras/config/__init__.py (+6)

```diff
@@ -17,11 +17,17 @@
 from keras.src.backend.config import (
     is_flash_attention_enabled as is_flash_attention_enabled,
 )
+from keras.src.backend.config import max_epochs as max_epochs
+from keras.src.backend.config import max_steps_per_epoch as max_steps_per_epoch
 from keras.src.backend.config import set_epsilon as set_epsilon
 from keras.src.backend.config import set_floatx as set_floatx
 from keras.src.backend.config import (
     set_image_data_format as set_image_data_format,
 )
+from keras.src.backend.config import set_max_epochs as set_max_epochs
+from keras.src.backend.config import (
+    set_max_steps_per_epoch as set_max_steps_per_epoch,
+)
 from keras.src.dtype_policies.dtype_policy import dtype_policy as dtype_policy
 from keras.src.dtype_policies.dtype_policy import (
     set_dtype_policy as set_dtype_policy,
```

keras/api/config/__init__.py (+6)

```diff
@@ -17,11 +17,17 @@
 from keras.src.backend.config import (
     is_flash_attention_enabled as is_flash_attention_enabled,
 )
+from keras.src.backend.config import max_epochs as max_epochs
+from keras.src.backend.config import max_steps_per_epoch as max_steps_per_epoch
 from keras.src.backend.config import set_epsilon as set_epsilon
 from keras.src.backend.config import set_floatx as set_floatx
 from keras.src.backend.config import (
     set_image_data_format as set_image_data_format,
 )
+from keras.src.backend.config import set_max_epochs as set_max_epochs
+from keras.src.backend.config import (
+    set_max_steps_per_epoch as set_max_steps_per_epoch,
+)
 from keras.src.dtype_policies.dtype_policy import dtype_policy as dtype_policy
 from keras.src.dtype_policies.dtype_policy import (
     set_dtype_policy as set_dtype_policy,
```

keras/src/backend/config.py (+71 -1)

```diff
@@ -15,6 +15,10 @@
 # Default backend: TensorFlow.
 _BACKEND = "tensorflow"
 
+# Cap run duration for debugging.
+_MAX_EPOCHS = None
+_MAX_STEPS_PER_EPOCH = None
+
 
 @keras_export(["keras.config.floatx", "keras.backend.floatx"])
 def floatx():
@@ -304,7 +308,10 @@ def keras_home():
     _backend = os.environ["KERAS_BACKEND"]
     if _backend:
         _BACKEND = _backend
-
+if "KERAS_MAX_EPOCHS" in os.environ:
+    _MAX_EPOCHS = int(os.environ["KERAS_MAX_EPOCHS"])
+if "KERAS_MAX_STEPS_PER_EPOCH" in os.environ:
+    _MAX_STEPS_PER_EPOCH = int(os.environ["KERAS_MAX_STEPS_PER_EPOCH"])
 
 if _BACKEND != "tensorflow":
     # If we are not running on the tensorflow backend, we should stop tensorflow
@@ -333,3 +340,66 @@ def backend():
 
     """
     return _BACKEND
+
+
+@keras_export(["keras.config.set_max_epochs"])
+def set_max_epochs(max_epochs):
+    """Limit the maximum number of epochs for any call to fit.
+
+    This will cap the number of epochs for any training run using `model.fit()`.
+    This is purely for debugging, and can also be set via the `KERAS_MAX_EPOCHS`
+    environment variable to quickly run a script without modifying its source.
+
+    Args:
+        max_epochs: The integer limit on the number of epochs or `None`. If
+            `None`, no limit is applied.
+    """
+    global _MAX_EPOCHS
+    _MAX_EPOCHS = max_epochs
+
+
+@keras_export(["keras.config.set_max_steps_per_epoch"])
+def set_max_steps_per_epoch(max_steps_per_epoch):
+    """Limit the maximum number of steps for any call to fit/evaluate/predict.
+
+    This will cap the number of steps for a single epoch of a call to `fit()`,
+    `evaluate()`, or `predict()`. This is purely for debugging, and can also be
+    set via the `KERAS_MAX_STEPS_PER_EPOCH` environment variable to quickly run
+    a script without modifying its source.
+
+    Args:
+        max_steps_per_epoch: The integer limit on the number of steps or
+            `None`. If `None`, no limit is applied.
+    """
+    global _MAX_STEPS_PER_EPOCH
+    _MAX_STEPS_PER_EPOCH = max_steps_per_epoch
+
+
+@keras_export(["keras.config.max_epochs"])
+def max_epochs():
+    """Get the maximum number of epochs for any call to fit.
+
+    Retrieves the limit on the number of epochs set by
+    `keras.config.set_max_epochs` or the `KERAS_MAX_EPOCHS` environment
+    variable.
+
+    Returns:
+        The integer limit on the number of epochs or `None`, if no limit has
+        been set.
+    """
+    return _MAX_EPOCHS
+
+
+@keras_export(["keras.config.max_steps_per_epoch"])
+def max_steps_per_epoch():
+    """Get the maximum number of steps for any call to fit/evaluate/predict.
+
+    Retrieves the limit on the number of steps set by
+    `keras.config.set_max_steps_per_epoch` or the `KERAS_MAX_STEPS_PER_EPOCH`
+    environment variable.
+
+    Returns:
+        The integer limit on the number of steps or `None`, if no limit has
+        been set.
+    """
+    return _MAX_STEPS_PER_EPOCH
```

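The new environment variables are read at module import time (the checks sit at module level in `config.py`), so they have to be in place before Keras is imported. A rough sketch of driving this from Python rather than the shell; the file name and values are illustrative:

```python
# run_capped.py -- illustrative wrapper, not part of the commit.
import os

# Must be set before Keras is imported, since config.py reads these
# variables when the module is first loaded.
os.environ["KERAS_MAX_EPOCHS"] = "1"
os.environ["KERAS_MAX_STEPS_PER_EPOCH"] = "5"

import keras  # noqa: E402

print(keras.config.max_epochs())           # -> 1
print(keras.config.max_steps_per_epoch())  # -> 5
```
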
keras/src/backend/jax/trainer.py (+7)

```diff
@@ -1,5 +1,6 @@
 import collections
 import itertools
+import warnings
 from functools import partial
 
 import jax
@@ -9,6 +10,7 @@
 from keras.src import callbacks as callbacks_module
 from keras.src import optimizers as optimizers_module
 from keras.src import tree
+from keras.src.backend import config
 from keras.src.backend import distribution_lib as jax_distribution_lib
 from keras.src.distribution import distribution_lib
 from keras.src.trainers import trainer as base_trainer
@@ -341,6 +343,11 @@ def fit(
         validation_freq=1,
     ):
         self._assert_compile_called("fit")
+        # Possibly cap epochs for debugging runs.
+        max_epochs = config.max_epochs()
+        if max_epochs and max_epochs < epochs:
+            warnings.warn("Limiting epochs to %d" % max_epochs)
+            epochs = max_epochs
         # TODO: respect compiled trainable state
         self._eval_epoch_iterator = None
         if validation_split and validation_data is None:
```

keras/src/backend/tensorflow/trainer.py (+6)

```diff
@@ -9,6 +9,7 @@
 from keras.src import metrics as metrics_module
 from keras.src import optimizers as optimizers_module
 from keras.src import tree
+from keras.src.backend import config
 from keras.src.losses import loss as loss_module
 from keras.src.trainers import trainer as base_trainer
 from keras.src.trainers.data_adapters import array_slicing
@@ -309,6 +310,11 @@ def fit(
         validation_freq=1,
     ):
         self._assert_compile_called("fit")
+        # Possibly cap epochs for debugging runs.
+        max_epochs = config.max_epochs()
+        if max_epochs and max_epochs < epochs:
+            warnings.warn("Limiting epochs to %d" % max_epochs)
+            epochs = max_epochs
         # TODO: respect compiled trainable state
         self._eval_epoch_iterator = None
         if validation_split and validation_data is None:
```

keras/src/backend/torch/trainer.py (+6)

```diff
@@ -8,6 +8,7 @@
 from keras.src import callbacks as callbacks_module
 from keras.src import optimizers as optimizers_module
 from keras.src import tree
+from keras.src.backend import config
 from keras.src.trainers import trainer as base_trainer
 from keras.src.trainers.data_adapters import array_slicing
 from keras.src.trainers.data_adapters import data_adapter_utils
@@ -187,6 +188,11 @@ def fit(
             raise ValueError(
                 "You must call `compile()` before calling `fit()`."
             )
+        # Possibly cap epochs for debugging runs.
+        max_epochs = config.max_epochs()
+        if max_epochs and max_epochs < epochs:
+            warnings.warn("Limiting epochs to %d" % max_epochs)
+            epochs = max_epochs
 
         # TODO: respect compiled trainable state
         self._eval_epoch_iterator = None
```

keras/src/trainers/epoch_iterator.py (+9)

```diff
@@ -42,6 +42,7 @@
 import contextlib
 import warnings
 
+from keras.src.backend import config
 from keras.src.trainers import data_adapters
 
 
@@ -57,6 +58,14 @@ def __init__(
         class_weight=None,
         steps_per_execution=1,
     ):
+        # Possibly cap steps_per_epoch for debugging runs.
+        max_steps_per_epoch = config.max_steps_per_epoch()
+        if max_steps_per_epoch:
+            if not steps_per_epoch or max_steps_per_epoch < steps_per_epoch:
+                warnings.warn(
+                    "Limiting steps_per_epoch to %d" % max_steps_per_epoch
+                )
+                steps_per_epoch = max_steps_per_epoch
         self.steps_per_epoch = steps_per_epoch
         self.steps_per_execution = steps_per_execution
         self._current_iterator = None
```

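The step cap only ever shrinks the effective `steps_per_epoch`: it applies when no `steps_per_epoch` was given or when the requested value exceeds the cap, and a smaller request is left untouched. A standalone sketch of that decision logic, with an illustrative helper name mirroring the code added above:

```python
import warnings


def capped_steps_per_epoch(steps_per_epoch, max_steps_per_epoch):
    """Mirror of the capping logic added in epoch_iterator.py (sketch only)."""
    if max_steps_per_epoch:
        if not steps_per_epoch or max_steps_per_epoch < steps_per_epoch:
            warnings.warn(
                "Limiting steps_per_epoch to %d" % max_steps_per_epoch
            )
            return max_steps_per_epoch
    return steps_per_epoch


assert capped_steps_per_epoch(None, 5) == 5      # unbounded run gets capped
assert capped_steps_per_epoch(100, 5) == 5       # large request gets capped
assert capped_steps_per_epoch(3, 5) == 3         # smaller request is kept
assert capped_steps_per_epoch(100, None) == 100  # no cap configured
```
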
keras/src/trainers/trainer_test.py (+43)

```diff
@@ -14,6 +14,7 @@
 from keras.src import ops
 from keras.src import optimizers
 from keras.src import testing
+from keras.src.backend import config
 from keras.src.backend.common.symbolic_scope import in_symbolic_scope
 from keras.src.callbacks.callback import Callback
 from keras.src.optimizers.rmsprop import RMSprop
@@ -1506,6 +1507,48 @@ def test_steps_per_epoch(self, steps_per_epoch_test, mode):
         )
         self.assertAllClose(model.evaluate(x, y), model_2.evaluate(x, y))
 
+    @pytest.mark.requires_trainable_backend
+    def test_max_epochs_and_steps(self):
+        batch_size = 8
+        epochs = 4
+        num_batches = 10
+        data_size = num_batches * batch_size
+        x, y = np.ones((data_size, 4)), np.ones((data_size, 1))
+        model = ExampleModel(units=1)
+        model.compile(
+            loss="mse",
+            optimizer="sgd",
+            metrics=[EpochAgnosticMeanSquaredError()],
+        )
+        step_observer = StepObserver()
+        model.fit(
+            x=x,
+            y=y,
+            batch_size=batch_size,
+            epochs=epochs,
+            callbacks=[step_observer],
+            verbose=0,
+        )
+        self.assertEqual(step_observer.epoch_begin_count, epochs)
+        self.assertEqual(step_observer.begin_count, num_batches * epochs)
+        try:
+            config.set_max_epochs(2)
+            config.set_max_steps_per_epoch(3)
+            step_observer = StepObserver()
+            model.fit(
+                x=x,
+                y=y,
+                batch_size=batch_size,
+                epochs=epochs,
+                callbacks=[step_observer],
+                verbose=0,
+            )
+            self.assertEqual(step_observer.epoch_begin_count, 2)
+            self.assertEqual(step_observer.begin_count, 6)
+        finally:
+            config.set_max_epochs(None)
+            config.set_max_steps_per_epoch(None)
+
     @parameterized.named_parameters(
         named_product(
             steps_per_epoch_test=[
```
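To check the change locally, the new test can be run on its own; a sketch, assuming a development checkout with pytest installed:

```
pytest keras/src/trainers/trainer_test.py -k test_max_epochs_and_steps
```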
