|
15 | 15 | # Default backend: TensorFlow.
|
16 | 16 | _BACKEND = "tensorflow"
|
17 | 17 |
|
| 18 | +# Cap run duration for debugging. |
| 19 | +_MAX_EPOCHS = None |
| 20 | +_MAX_STEPS_PER_EPOCH = None |
| 21 | + |
18 | 22 |
|
19 | 23 | @keras_export(["keras.config.floatx", "keras.backend.floatx"])
|
20 | 24 | def floatx():
|
@@ -304,7 +308,10 @@ def keras_home():
|
304 | 308 | _backend = os.environ["KERAS_BACKEND"]
|
305 | 309 | if _backend:
|
306 | 310 | _BACKEND = _backend
|
307 |
| - |
| 311 | +if "KERAS_MAX_EPOCHS" in os.environ: |
| 312 | + _MAX_EPOCHS = int(os.environ["KERAS_MAX_EPOCHS"]) |
| 313 | +if "KERAS_MAX_STEPS_PER_EPOCH" in os.environ: |
| 314 | + _MAX_STEPS_PER_EPOCH = int(os.environ["KERAS_MAX_STEPS_PER_EPOCH"]) |
308 | 315 |
|
309 | 316 | if _BACKEND != "tensorflow":
|
310 | 317 | # If we are not running on the tensorflow backend, we should stop tensorflow
|
@@ -333,3 +340,66 @@ def backend():
|
333 | 340 |
|
334 | 341 | """
|
335 | 342 | return _BACKEND
|
| 343 | + |
| 344 | + |
| 345 | +@keras_export(["keras.config.set_max_epochs"]) |
| 346 | +def set_max_epochs(max_epochs): |
| 347 | + """Limit the maximum number of epochs for any call to fit. |
| 348 | +
|
| 349 | + This will cap the number of epochs for any training run using `model.fit()`. |
| 350 | + This is purely for debugging, and can also be set via the `KERAS_MAX_EPOCHS` |
| 351 | + environment variable to quickly run a script without modifying its source. |
| 352 | +
|
| 353 | + Args: |
| 354 | + max_epochs: The integer limit on the number of epochs or `None`. If |
| 355 | + `None`, no limit is applied. |
| 356 | + """ |
| 357 | + global _MAX_EPOCHS |
| 358 | + _MAX_EPOCHS = max_epochs |
| 359 | + |
| 360 | + |
| 361 | +@keras_export(["keras.config.set_max_steps_per_epoch"]) |
| 362 | +def set_max_steps_per_epoch(max_steps_per_epoch): |
| 363 | + """Limit the maximum number of steps for any call to fit/evaluate/predict. |
| 364 | +
|
| 365 | + This will cap the number of steps for single epoch of a call to `fit()`, |
| 366 | + `evaluate()`, or `predict()`. This is purely for debugging, and can also be |
| 367 | + set via the `KERAS_MAX_STEPS_PER_EPOCH` environment variable to quickly run |
| 368 | + a scrip without modifying its source. |
| 369 | +
|
| 370 | + Args: |
| 371 | + max_epochs: The integer limit on the number of epochs or `None`. If |
| 372 | + `None`, no limit is applied. |
| 373 | + """ |
| 374 | + global _MAX_STEPS_PER_EPOCH |
| 375 | + _MAX_STEPS_PER_EPOCH = max_steps_per_epoch |
| 376 | + |
| 377 | + |
| 378 | +@keras_export(["keras.config.max_epochs"]) |
| 379 | +def max_epochs(): |
| 380 | + """Get the maximum number of epochs for any call to fit. |
| 381 | +
|
| 382 | + Retrieves the limit on the number of epochs set by |
| 383 | + `keras.config.set_max_epochs` or the `KERAS_MAX_EPOCHS` environment |
| 384 | + variable. |
| 385 | +
|
| 386 | + Returns: |
| 387 | + The integer limit on the number of epochs or `None`, if no limit has |
| 388 | + been set. |
| 389 | + """ |
| 390 | + return _MAX_EPOCHS |
| 391 | + |
| 392 | + |
| 393 | +@keras_export(["keras.config.max_steps_per_epoch"]) |
| 394 | +def max_steps_per_epoch(): |
| 395 | + """Get the maximum number of steps for any call to fit/evaluate/predict. |
| 396 | +
|
| 397 | + Retrieves the limit on the number of epochs set by |
| 398 | + `keras.config.set_max_steps_per_epoch` or the `KERAS_MAX_STEPS_PER_EPOCH` |
| 399 | + environment variable. |
| 400 | +
|
| 401 | + Args: |
| 402 | + max_epochs: The integer limit on the number of epochs or `None`. If |
| 403 | + `None`, no limit is applied. |
| 404 | + """ |
| 405 | + return _MAX_STEPS_PER_EPOCH |
0 commit comments