
Commit 6ec3368

Merge pull request #1102 from AI-Hypercomputer:fhzhang/rampup
PiperOrigin-RevId: 708497751
maxtext authors committed Dec 21, 2024
2 parents 834b778 + 1df6662 commit 6ec3368
Showing 3 changed files with 7 additions and 2 deletions.
MaxText/configs/base.yml: 2 additions & 0 deletions

```diff
@@ -456,6 +456,8 @@ inference_microbenchmark_stages: "prefill,generate"
 inference_microbenchmark_loop_iters: 10
 inference_microbenchmark_log_file_path: ""
 inference_metadata_file: "" # path to a json file
+inference_server: "MaxtextInterleavedServer" # inference server to start
+inference_benchmark_test: False
 enable_model_warmup: False
 
 # Stack prefill cache across the layer to reduce the
```
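
Both new keys default to the previous behavior, so existing runs are unaffected: `inference_server` names the server implementation that `maxengine_server.py` starts (see the last hunk), and `inference_benchmark_test` puts benchmark runs on a single slice with no distributed initialization (see the `max_utils.py` hunks). A minimal sketch of a benchmark invocation, assuming MaxText's usual `key=value` overrides after the base config path; the exact invocation is an assumption, not part of this commit:

```python
# Hedged sketch: launching the engine server with the new keys overridden.
# Assumes MaxText's standard "key=value" config overrides; not from this commit.
import subprocess

subprocess.run(
    [
        "python", "MaxText/maxengine_server.py", "MaxText/configs/base.yml",
        "inference_server=MaxtextInterleavedServer",  # default server implementation
        "inference_benchmark_test=true",  # single slice, no jax.distributed init
    ],
    check=True,  # raise CalledProcessError if the server exits non-zero
)
```
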
MaxText/max_utils.py: 4 additions & 1 deletion

```diff
@@ -219,6 +219,9 @@ def maybe_initialize_jax_distributed_system(raw_keys):
   For CPUs, we call jax.distributed.initialize() explicitly, with the specified arguments.
   """
+  if raw_keys["inference_benchmark_test"]:
+    # Disable initialization for the inference benchmark test.
+    return
   if raw_keys["compile_topology"]:
     # Don't initialize jax distributed with AOT compilation
     return
```
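
The new guard is the first branch in the function, so a benchmark run returns before the `compile_topology` check or any `jax.distributed.initialize()` path is reached. A standalone sketch of that ordering; the `initialize` stub is a placeholder, not MaxText's real setup code:

```python
# Ordering sketch only: mirrors the guard structure added in max_utils.py.
def maybe_initialize_jax_distributed_system(raw_keys: dict) -> None:
  if raw_keys["inference_benchmark_test"]:
    # Benchmark tests run single-process; skip distributed setup entirely.
    return
  if raw_keys["compile_topology"]:
    # AOT compilation also runs without the jax distributed system.
    return
  initialize()  # placeholder for the real jax.distributed.initialize() logic

def initialize() -> None:
  print("jax.distributed.initialize() would run here")

# The benchmark guard wins even when other triggers are set:
maybe_initialize_jax_distributed_system(
    {"inference_benchmark_test": True, "compile_topology": True}
)  # prints nothing
```
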
```diff
@@ -531,7 +534,7 @@ def create_device_mesh(config, devices=None):
   if devices is None:
     devices = jax.devices()
   num_devices = len(devices)
-  num_slices = config.num_slices
+  num_slices = 1 if config.inference_benchmark_test else config.num_slices
   num_devices_per_slice = num_devices // num_slices
 
   multi_slice_env = num_slices > 1
```
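
Forcing `num_slices` to 1 makes the benchmark path treat every visible device as part of a single slice, so the multi-slice branch is never taken. A worked example of the arithmetic, in plain Python with no JAX dependency:

```python
# Worked example of the num_slices override in create_device_mesh above.
def slice_layout(num_devices: int, num_slices: int, benchmark_test: bool):
  num_slices = 1 if benchmark_test else num_slices
  num_devices_per_slice = num_devices // num_slices
  multi_slice_env = num_slices > 1
  return num_slices, num_devices_per_slice, multi_slice_env

print(slice_layout(8, 2, False))  # (2, 4, True)  -> normal two-slice mesh
print(slice_layout(8, 2, True))   # (1, 8, False) -> single-slice benchmark mesh
```
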
MaxText/maxengine_server.py: 1 addition & 1 deletion

```diff
@@ -37,7 +37,7 @@
 def main(config):
   # No devices for local cpu test. A None for prefill and a None for generate.
   devices = server_lib.get_devices()
-  server_config = maxengine_config.get_server_config("MaxtextInterleavedServer", config)
+  server_config = maxengine_config.get_server_config(config.inference_server, config)
 
   metrics_server_config: config_lib.MetricsServerConfig | None = None
   if config.prometheus_port != 0:
```
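
Replacing the hardcoded literal with `config.inference_server` makes the served implementation a config choice, while the `base.yml` default preserves the old behavior. A sketch of the kind of name-based dispatch this enables; the registry and factory below are hypothetical illustrations, not `maxengine_config`'s actual API:

```python
# Hypothetical name -> factory registry; MaxText's real get_server_config
# is not shown in this commit, so this only illustrates the pattern.
from typing import Any, Callable, Dict

SERVER_REGISTRY: Dict[str, Callable[[Any], Any]] = {
    "MaxtextInterleavedServer": lambda config: {"server": "interleaved", "config": config},
}

def get_server_config(name: str, config: Any) -> Any:
  try:
    return SERVER_REGISTRY[name](config)
  except KeyError:
    known = ", ".join(sorted(SERVER_REGISTRY))
    raise ValueError(f"Unknown inference_server {name!r}; known: {known}") from None
```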
