[pmap] Avoid degraded performance under the new jax.pmap.
#5152
+12
−1
jax.pmap.
#5152