Fix: Pass einsum_str instead of module name to get_axis_to_reduce_from_einsum_str #480
Summary
Fixes #479
This PR fixes a bug in `SimulateQuantizedEinsum` where `self.wrapped.name` (the Flax module name) was incorrectly passed to `get_axis_to_reduce_from_einsum_str()` instead of the actual einsum equation string. This caused the pattern-specific quantization axis selection logic to never execute, so all einsum operations fell back to generic per-channel scaling instead of using the intended pattern-specific scaling that respects the einsum structure.
Changes
- `gemma/peft/_quantization.py`: replaced `einsum_str=self.wrapped.name` with `einsum_str=einsum_str` in the call to `get_axis_to_reduce_from_einsum_str()`.
The fix is minimal and low-risk, as it only changes which string is passed to an existing function. The `einsum_str` variable is already validated and available in the same function scope (lines 172-175), so the pattern-matching logic in `get_axis_to_reduce_from_einsum_str()` can now correctly identify einsum equations and return the appropriate quantization axes.
Testing
A demonstration script (`examples/einsum_quantization_bug_demo.py`) was created to verify the bug and confirm the fix. The script shows that:
- Before the fix, `get_axis_to_reduce_from_einsum_str()` receives the module name and returns `None`.
- After the fix, `get_axis_to_reduce_from_einsum_str()` receives the einsum equation and returns the correct axis tuple.
Impact
This fix enables the pattern-specific quantization axis selection that was intended by the original code authors, improving quantization accuracy for einsum operations in quantization-aware training workflows.
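The failure mode described in the summary can be sketched in a few lines of plain Python. This is an illustration only, not the code in `gemma/peft/_quantization.py`: the pattern table, the helper names `lookup_reduce_axes` and `select_scaling`, the module name `q_einsum`, and the assumption that the generic fallback reduces over all but the last weight axis are all hypothetical.

```python
# Hypothetical pattern table: einsum equation -> weight axes to reduce over
# when computing quantization scales. Entries are illustrative only and are
# not taken from the gemma code base.
_PATTERN_AXES = {
    "BTD,NDH->BTNH": (1,),
    "BTNH,NHD->BTD": (0, 1),
}


def lookup_reduce_axes(einsum_str):
    """Stand-in for get_axis_to_reduce_from_einsum_str(); None if unmatched."""
    return _PATTERN_AXES.get(einsum_str)


def select_scaling(einsum_str, weight_ndim):
    """Return (strategy, axes) for a weight tensor of rank weight_ndim."""
    axes = lookup_reduce_axes(einsum_str)
    if axes is None:
        # Assumed generic fallback: reduce over every axis except the last.
        return "per-channel fallback", tuple(range(weight_ndim - 1))
    return "pattern-specific", axes


module_name = "q_einsum"        # hypothetical Flax module name
einsum_str = "BTD,NDH->BTNH"    # the actual einsum equation

# Before the fix: a module name never matches an equation pattern.
before = select_scaling(module_name, weight_ndim=3)
# After the fix: the equation matches and the pattern axes are used.
after = select_scaling(einsum_str, weight_ndim=3)

print(before)
print(after)
```

Because a module name can never equal an einsum equation, the lookup silently returns `None` on every call, which is why the bug produced no error and only showed up as a different choice of scaling axes.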