Fix: SimulateQuantizedEinsum passes module name instead of einsum_str to get_axis_to_reduce_from_einsum_str #479

@markknoffler

Description

SimulateQuantizedEinsum.__call__ incorrectly passes self.wrapped.name (the Flax module name) to get_axis_to_reduce_from_einsum_str() instead of the actual einsum equation string. This causes the pattern-specific quantization axis selection logic to never execute, resulting in all einsum operations falling back to generic per-channel scaling instead of using the intended pattern-specific scaling.

Problem

In gemma/peft/_quantization.py, the SimulateQuantizedEinsum wrapper is designed to use pattern-specific quantization axes for different einsum equations. The function get_axis_to_reduce_from_einsum_str() contains explicit pattern matching for common einsum equations like "BTD,NDH->BTNH", "...H,HF->...F", etc., and returns the appropriate axis tuple for quantization scaling.
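For reference, the function has roughly this shape (a paraphrased sketch, not the verbatim source: the (1,) mapping for "BTD,NDH->BTNH" is the one this report documents; the axis shown for the feed-forward pattern is illustrative):

```python
def get_axis_to_reduce_from_einsum_str(einsum_str: str) -> tuple[int, ...] | None:
  """Returns the kernel axes to quantize over for known einsum equations."""
  if einsum_str == "BTD,NDH->BTNH":
    return (1,)  # reduce over D, axis 1 of the (N, D, H) kernel
  if einsum_str == "...H,HF->...F":
    return (0,)  # illustrative: reduce over H, axis 0 of the (H, F) kernel
  return None  # unknown equation: caller falls back to per-channel scaling
```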

However, on line 192, the code passes self.wrapped.name (the module name, e.g., "einsum_0" or "attention_proj") instead of the actual einsum_str:
```python
kernel = simulate_quantize(
    kernel,
    self.method,
    axis_to_reduce=get_axis_to_reduce_from_einsum_str(
        einsum_str=self.wrapped.name  # ❌ BUG: should be einsum_str
    ),
)
```

Since get_axis_to_reduce_from_einsum_str() only matches on einsum equation strings, passing a module name means it always falls through to the default case and returns None. This causes simulate_quantize() to use generic per-channel scaling along the last axis, ignoring the einsum structure.
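The fallthrough is easy to see in isolation (assuming the function is importable from gemma.peft._quantization; the (1,) value is the one this report expects for the attention pattern):

```python
from gemma.peft._quantization import get_axis_to_reduce_from_einsum_str

# A real equation matches a pattern and yields an axis tuple.
assert get_axis_to_reduce_from_einsum_str(einsum_str="BTD,NDH->BTNH") == (1,)

# A module name matches nothing, so it falls through to the default.
assert get_axis_to_reduce_from_einsum_str(einsum_str="attention_proj") is None
```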

Impact

  1. Dead Code: The pattern-specific logic in get_axis_to_reduce_from_einsum_str() is never executed, making it effectively dead code.

  2. Incorrect Quantization: All einsum operations use generic per-channel scaling instead of pattern-specific scaling, leading to suboptimal quantization accuracy.

  3. Affects All Einsum Operations: Any einsum wrapped by SimulateQuantizedEinsum is affected, regardless of the equation pattern.

  4. Silent Failure: The code runs without errors, but quantization accuracy is degraded, making this bug difficult to detect.

Steps to Reproduce

  1. Create a SimulateQuantizedEinsum instance with a known einsum equation:

     from gemma.peft import _quantization
     from gemma.peft import _quantization_utils
     from flax import linen as nn

     wrapped = nn.Einsum(
         einsum_str="BTD,NDH->BTNH", shape=(4, 8, 16), name="attention_proj"
     )
     quantized = _quantization.SimulateQuantizedEinsum(
         wrapped=wrapped,
         method=_quantization_utils.QuantizationMethod.INT4,
     )

  2. Intercept calls to get_axis_to_reduce_from_einsum_str() to see what argument is passed (one way to do this is sketched after the list).

  3. Observe that the function receives "attention_proj" (or "wrapped") instead of "BTD,NDH->BTNH".

  4. Verify that get_axis_to_reduce_from_einsum_str("attention_proj") returns None instead of the expected (1,).
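A minimal sketch of the interception in step 2, assuming the call site looks the function up on the gemma.peft._quantization module at call time (so a module-level patch intercepts it):

```python
from unittest import mock

from gemma.peft import _quantization

# Record every argument the axis-selection function receives.
real_fn = _quantization.get_axis_to_reduce_from_einsum_str
received = []

def spy(einsum_str):
  received.append(einsum_str)
  return real_fn(einsum_str)

with mock.patch.object(
    _quantization, "get_axis_to_reduce_from_einsum_str", side_effect=spy
):
  # ... run a forward pass through `quantized` here ...
  pass

print(received)  # observed: ["attention_proj"]; expected: ["BTD,NDH->BTNH"]
```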

Expected Behavior

When SimulateQuantizedEinsum.__call__ is invoked:

  1. It should retrieve the einsum equation string from self.wrapped.einsum_str (which is already done on lines 172-175).
  2. It should pass this equation string to get_axis_to_reduce_from_einsum_str().
  3. The function should match the pattern and return the appropriate axis tuple (e.g., (1,) for "BTD,NDH->BTNH").
  4. Quantization should use pattern-specific scaling that respects the einsum structure.

Actual Behavior

  1. The code passes self.wrapped.name (module name) to get_axis_to_reduce_from_einsum_str().
  2. The function doesn't match any pattern and returns None.
  3. Quantization falls back to generic per-channel scaling along the last axis.
  4. Pattern-specific scaling is never used, reducing quantization accuracy.

Proposed Solution

Change line 192 in gemma/peft/_quantization.py from:

```python
axis_to_reduce=get_axis_to_reduce_from_einsum_str(
    einsum_str=self.wrapped.name
),
```

to:

```python
axis_to_reduce=get_axis_to_reduce_from_einsum_str(
    einsum_str=einsum_str
),
```

The einsum_str variable is already validated and available in the same function scope (lines 172-175), so this is a simple one-line fix.
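For context, the fixed call site would read roughly as follows (a sketch: the einsum_str retrieval paraphrases the existing lines 172-175 rather than quoting them, and intervening code is elided):

```python
def __call__(self, *args, **kwargs):
  # einsum_str is already retrieved and validated earlier in __call__
  # (lines 172-175 of gemma/peft/_quantization.py).
  einsum_str = self.wrapped.einsum_str
  ...
  kernel = simulate_quantize(
      kernel,
      self.method,
      axis_to_reduce=get_axis_to_reduce_from_einsum_str(
          einsum_str=einsum_str  # ✅ pass the equation, not the module name
      ),
  )
```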

Additional Context

This bug affects quantization-aware training (QAT) workflows where einsum operations are quantized. The pattern-specific axis selection was clearly intended by the code authors (as evidenced by the explicit pattern matching in get_axis_to_reduce_from_einsum_str()), but the bug prevents this logic from ever executing.

The fix is minimal and low-risk, as it only changes which string is passed to an existing function that already handles einsum equation strings correctly. I will attach a PR immediately after posting this issue.
