Description
`SimulateQuantizedEinsum.__call__` incorrectly passes `self.wrapped.name` (the Flax module name) to `get_axis_to_reduce_from_einsum_str()` instead of the actual einsum equation string. As a result, the pattern-specific quantization axis selection logic never executes, and all einsum operations fall back to generic per-channel scaling instead of using the intended pattern-specific scaling.
Problem
In `gemma/peft/_quantization.py`, the `SimulateQuantizedEinsum` wrapper is designed to use pattern-specific quantization axes for different einsum equations. The function `get_axis_to_reduce_from_einsum_str()` contains explicit pattern matching for common einsum equations such as `"BTD,NDH->BTNH"` and `"...H,HF->...F"`, and returns the appropriate axis tuple for quantization scaling.
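The dispatch described above can be sketched as a simple table lookup. This is a hypothetical stand-in, not the real implementation: only the one pattern/axis pair documented in this issue is listed, and the actual function in `gemma/peft/_quantization.py` matches more equations.

```python
# Hypothetical sketch of the pattern-specific axis lookup; the real
# function in gemma/peft/_quantization.py handles more equations.
PATTERN_AXES = {
    "BTD,NDH->BTNH": (1,),  # reduce over the contracted D axis of the kernel
}

def get_axis_to_reduce_from_einsum_str(einsum_str: str):
    # Unrecognized strings (including module names) fall through to None,
    # which triggers the generic per-channel fallback downstream.
    return PATTERN_AXES.get(einsum_str)
```

The key behavior: any string that is not a known equation, such as a module name like `"attention_proj"`, yields `None`.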
However, on line 192, the code passes `self.wrapped.name` (the module name, e.g., `"einsum_0"` or `"attention_proj"`) instead of the actual `einsum_str`:
```python
kernel = simulate_quantize(
    kernel,
    self.method,
    axis_to_reduce=get_axis_to_reduce_from_einsum_str(
        einsum_str=self.wrapped.name  # ❌ BUG: should be einsum_str
    ),
)
```

Since `get_axis_to_reduce_from_einsum_str()` only matches on einsum equation strings, passing a module name always falls through to the default case and returns `None`. This causes `simulate_quantize()` to use generic per-channel scaling along the last axis, ignoring the einsum structure.
Impact
- Dead Code: the pattern-specific logic in `get_axis_to_reduce_from_einsum_str()` is never executed, making it effectively dead code.
- Incorrect Quantization: all einsum operations use generic per-channel scaling instead of pattern-specific scaling, leading to suboptimal quantization accuracy.
- Affects All Einsum Operations: any einsum wrapped by `SimulateQuantizedEinsum` is affected, regardless of the equation pattern.
- Silent Failure: the code runs without errors, but quantization accuracy is degraded, making this bug difficult to detect.
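To make the second point concrete, here is a toy comparison of the two scale layouts for the `(N, D, H)` kernel of `"BTD,NDH->BTNH"`. This is illustrative only: it assumes a simple abs-max scale and does not reproduce `simulate_quantize()`'s exact arithmetic.

```python
import numpy as np

# Toy kernel in the (N, D, H) layout used by "BTD,NDH->BTNH".
kernel = np.random.default_rng(0).normal(size=(4, 8, 16)).astype(np.float32)

# Pattern-specific scaling: reduce over the contracted D axis (axis 1),
# giving one scale per (N, H) pair.
scale_pattern = np.abs(kernel).max(axis=1, keepdims=True)

# Generic fallback: reduce along the last axis, ignoring which axis the
# einsum actually contracts.
scale_generic = np.abs(kernel).max(axis=-1, keepdims=True)

print(scale_pattern.shape)  # (4, 1, 16)
print(scale_generic.shape)  # (4, 8, 1)
```

The two reductions produce scales with entirely different layouts, which is why the silent fallback changes the quantization behavior rather than merely the code path.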
Steps to Reproduce

1. Create a `SimulateQuantizedEinsum` instance with a known einsum equation:

   ```python
   from gemma.peft import _quantization
   from gemma.peft import _quantization_utils
   from flax import linen as nn

   wrapped = nn.Einsum(
       einsum_str="BTD,NDH->BTNH", shape=(4, 8, 16), name="attention_proj"
   )
   quantized = _quantization.SimulateQuantizedEinsum(
       wrapped=wrapped,
       method=_quantization_utils.QuantizationMethod.INT4,
   )
   ```

2. Intercept calls to `get_axis_to_reduce_from_einsum_str()` to see what argument is passed.
3. Observe that the function receives `"attention_proj"` (or `"wrapped"`) instead of `"BTD,NDH->BTNH"`.
4. Verify that `get_axis_to_reduce_from_einsum_str("attention_proj")` returns `None` instead of the expected `(1,)`.
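Step 2 can be sketched with `unittest.mock.patch.object`. The example below uses a local stand-in for the module so it is self-contained; in practice one would patch `gemma.peft._quantization.get_axis_to_reduce_from_einsum_str` instead.

```python
import types
from unittest import mock

# Local stand-in for the real lookup (hypothetical, self-contained).
def _lookup(einsum_str):
    return {"BTD,NDH->BTNH": (1,)}.get(einsum_str)

quantization = types.SimpleNamespace(get_axis_to_reduce_from_einsum_str=_lookup)

with mock.patch.object(
    quantization, "get_axis_to_reduce_from_einsum_str", wraps=_lookup
) as spy:
    # Mirrors the buggy call site: the module name is passed as einsum_str.
    result = quantization.get_axis_to_reduce_from_einsum_str(
        einsum_str="attention_proj"
    )

print(spy.call_args.kwargs)  # {'einsum_str': 'attention_proj'}
print(result)                # None
```

The spy confirms both halves of the bug: the module name is what reaches the lookup, and the lookup therefore returns `None`.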
Expected Behavior
When `SimulateQuantizedEinsum.__call__` is invoked:
- It should retrieve the einsum equation string from `self.wrapped.einsum_str` (which is already done on lines 172-175).
- It should pass this equation string to `get_axis_to_reduce_from_einsum_str()`.
- The function should match the pattern and return the appropriate axis tuple (e.g., `(1,)` for `"BTD,NDH->BTNH"`).
- Quantization should use pattern-specific scaling that respects the einsum structure.
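As a sanity check on the expected `(1,)`: in `"BTD,NDH->BTNH"` the kernel has layout `(N, D, H)`, and `D` (axis 1 of the kernel) is the dimension contracted away, as a small NumPy example confirms:

```python
import numpy as np

x = np.ones((2, 3, 8), dtype=np.float32)        # activations, (B, T, D)
kernel = np.ones((4, 8, 16), dtype=np.float32)  # kernel, (N, D, H)

# D (size 8) is contracted: it is axis 1 of the kernel, matching the (1,)
# reduction axis the lookup is expected to return for this pattern.
out = np.einsum("BTD,NDH->BTNH", x, kernel)
print(out.shape)  # (2, 3, 4, 16)
```

With all-ones inputs, each output element is the sum over the contracted `D` axis, i.e. `8.0`.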
Actual Behavior
- The code passes `self.wrapped.name` (the module name) to `get_axis_to_reduce_from_einsum_str()`.
- The function doesn't match any pattern and returns `None`.
- Quantization falls back to generic per-channel scaling along the last axis.
- Pattern-specific scaling is never used, reducing quantization accuracy.
Proposed Solution
Change line 192 in `gemma/peft/_quantization.py` from:

```python
axis_to_reduce=get_axis_to_reduce_from_einsum_str(
    einsum_str=self.wrapped.name
),
```

to:

```python
axis_to_reduce=get_axis_to_reduce_from_einsum_str(
    einsum_str=einsum_str
),
```

The `einsum_str` variable is already validated and available in the same function scope (lines 172-175), so this is a simple one-line fix.
Additional Context
This bug affects quantization-aware training (QAT) workflows where einsum operations are quantized. The pattern-specific axis selection was clearly intended by the code authors (as evidenced by the explicit pattern matching in get_axis_to_reduce_from_einsum_str()), but the bug prevents this logic from ever executing.
The fix is minimal and low-risk: it only changes which string is passed to an existing function that already handles einsum equation strings correctly. I will attach a PR immediately after posting this issue.
