Currently, we pass a dummy sharding (replicated on all devices) to our inputs, because JAX needs to be sharding-aware when executing a graph. The proper way to do this is to obtain the sharding information from the StableHLO compilation, either through protobuf or `xla::CompileOptions`, but that would require linking against the whole XLA repo. This topic needs more exploration.
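For reference, a "replicated on all devices" sharding can be expressed on the JAX side roughly as below. This is a minimal sketch, assuming the standard `jax.sharding` API; the helper names `replicated_sharding` and `place_inputs_replicated` and the mesh axis name `"d"` are hypothetical, and this is not how the project currently injects the dummy sharding (that happens below the Python layer).

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec


def replicated_sharding() -> NamedSharding:
    # Hypothetical helper: build a 1-D mesh over every visible device.
    mesh = Mesh(np.array(jax.devices()), axis_names=("d",))
    # An empty PartitionSpec splits no array axis, so every device
    # holds a full copy of the array (i.e. fully replicated).
    return NamedSharding(mesh, PartitionSpec())


def place_inputs_replicated(*arrays):
    # Hypothetical helper: put each input on all devices with the
    # replicated sharding, mirroring the "dummy sharding" described above.
    sharding = replicated_sharding()
    return tuple(jax.device_put(a, sharding) for a in arrays)


x, y = place_inputs_replicated(jnp.ones((4, 4)), jnp.arange(8.0))
```

Replication is the safe default because it makes no assumption about how the compiled graph partitions its inputs; the trade-off is that it discards whatever sharding the StableHLO module actually expects, which is why reading it from the compile step would be preferable.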