From 35d7e49797398d55adffaac42bcab2265476b5e7 Mon Sep 17 00:00:00 2001 From: Han Qi Date: Fri, 10 Jan 2025 21:54:21 +0000 Subject: [PATCH] skip keys not needed --- torchprime/experimental/torchax_models/run.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torchprime/experimental/torchax_models/run.py b/torchprime/experimental/torchax_models/run.py index eefc698..990fc5e 100644 --- a/torchprime/experimental/torchax_models/run.py +++ b/torchprime/experimental/torchax_models/run.py @@ -161,6 +161,8 @@ def create_sharded_weights(model, mesh, sharding_map): def create_weights(rng): res = {} for name, weight_meta in model.state_dict().items(): + if _process_sharding_name(name) not in sharding_map: + continue rng, subkey = jax.random.split(rng) if len(weight_meta.shape) < 2: res[name] = jax.random.normal(subkey, weight_meta.shape,