Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions examples/wmt/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,11 @@ def pre_pmap(xs):
)

def post_pmap(xs):
# Avoid degraded performance under the new jax.pmap. See
# https://docs.jax.dev/en/latest/migrate_pmap.html#int-indexing-into-sharded-arrays.
if jax.config.jax_pmap_shmap_merge:
return jax.tree_util.tree_map(
lambda x: x.addressable_shards[0].data.squeeze(0), xs)
return jax.tree_util.tree_map(lambda x: x[0], xs)

return post_pmap(host_psum(pre_pmap(in_tree)))
Expand Down
8 changes: 7 additions & 1 deletion flax/training/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,13 @@ def get_metrics(device_metrics):
"""
# We select the first element of x in order to get a single copy of a
# device-replicated metric.
device_metrics = jax.tree_util.tree_map(lambda x: x[0], device_metrics)
# Avoid degraded performance under the new jax.pmap. See
# https://docs.jax.dev/en/latest/migrate_pmap.html#int-indexing-into-sharded-arrays.
if jax.config.jax_pmap_shmap_merge:
device_metrics = jax.tree_util.tree_map(
lambda x: x.addressable_shards[0].data.squeeze(0), device_metrics)
else:
device_metrics = jax.tree_util.tree_map(lambda x: x[0], device_metrics)
metrics_np = jax.device_get(device_metrics)
return stack_forest(metrics_np)

Expand Down
Loading