Skip to content

Conversation

@markknoffler
Copy link

Summary

Fixes #479

This PR fixes a bug in SimulateQuantizedEinsum where self.wrapped.name (the Flax module name) was incorrectly passed to get_axis_to_reduce_from_einsum_str() instead of the actual einsum equation string. This caused the pattern-specific quantization axis selection logic to never execute, resulting in all einsum operations falling back to generic per-channel scaling instead of using the intended pattern-specific scaling that respects the einsum structure.

Changes

  • File: gemma/peft/_quantization.py
  • Line: 192
  • Change: Replaced einsum_str=self.wrapped.name with einsum_str=einsum_str in the call to get_axis_to_reduce_from_einsum_str()

The fix is minimal and low-risk, as it only changes which string is passed to an existing function. The einsum_str variable is already validated and available in the same function scope (lines 172-175), so this ensures the pattern matching logic in get_axis_to_reduce_from_einsum_str() can correctly identify einsum equations and return the appropriate quantization axes.

Testing

A demonstration script (examples/einsum_quantization_bug_demo.py) was created to verify the bug and confirm the fix. The script shows that:

  • Before fix: get_axis_to_reduce_from_einsum_str() receives module name and returns None
  • After fix: get_axis_to_reduce_from_einsum_str() receives einsum equation and returns the correct axis tuple

Impact

This fix enables the pattern-specific quantization axis selection that was intended by the original code authors, improving quantization accuracy for einsum operations in quantization-aware training workflows.

Screenshot 2025-12-28 at 12 57 33 PM

@google-cla
Copy link

google-cla bot commented Dec 28, 2025

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Fix: SimulateQuantizedEinsum passes module name instead of einsum_str to get_axis_to_reduce_from_einsum_str

1 participant