Skip to content

Conversation

@cgarciae
Copy link
Collaborator

@cgarciae cgarciae commented Nov 25, 2025

Summary of Changes

This PR introduces significant refactoring to flax.nnx to standardize variable metadata, improve Hijax support, and consolidate APIs. It also includes fixes for type hinting and documentation.

Key Changes

  • Variable Metadata Standardization:

    • Renamed metadata keys for consistency:
      • is_hijaxhijax
      • has_refref
      • is_mutablemutable
  • New var_defaults API:

    • Introduced nnx.var_defaults() to replace nnx.use_hijax() and nnx.using_hijax().
    • This unified API serves as:
      • A context manager: with nnx.var_defaults(hijax=True): ...
      • A decorator: @nnx.var_defaults(hijax=True)
      • A configuration accessor: defaults = nnx.var_defaults()
  • Unified Variable Conversion API:

    • Consolidated multiple conversion functions into a single nnx.vars_as() function:
      • nnx.as_ref_vars(...)nnx.vars_as(..., ref=True)
      • nnx.as_immutable_vars(...)nnx.vars_as(..., mutable=False)
      • nnx.as_mutable_vars(...)nnx.vars_as(..., mutable=True)
      • nnx.as_hijax_vars(...)nnx.vars_as(..., hijax=True)

Fixes & Improvements

  • Type Hinting (Mypy):
    • Resolved Ref redefinition errors in flax/nnx/variablelib.py.
    • Fixed @tp.overload value shadowing for var_defaults.
    • Corrected property usage on non-method attributes.
  • Documentation:
    • Fixed indentation in recursive_map docstring in flax/nnx/graph.py.
    • Fixed RST transition syntax error in docs_nnx/hijax/index.rst.

Testing

  • Updated integration and unit tests (integration_test.py, mutable_array_test.py, spmd_test.py, variable_test.py) to utilize the new vars_as and var_defaults APIs.
  • Added comprehensive tests for the var_defaults context manager, including nesting behavior.
  • Verified changes with ./tests/run_all_tests.sh --only-mypy and ./tests/run_all_tests.sh --only-doctest.

@review-notebook-app
Copy link

Check out this pull request on  ReviewNB

See visual diffs & provide feedback on Jupyter Notebooks.


Powered by ReviewNB

@cgarciae cgarciae force-pushed the hijax-guide branch 2 times, most recently from 0520c58 to a7f8af9 Compare December 2, 2025 22:03
@cgarciae cgarciae force-pushed the hijax-guide branch 3 times, most recently from 8ce1c26 to e99e55f Compare December 12, 2025 16:02
@cgarciae cgarciae marked this pull request as ready for review December 12, 2025 16:02
@cgarciae cgarciae force-pushed the hijax-guide branch 4 times, most recently from 9cef051 to d68e156 Compare December 16, 2025 15:55
@copybara-service copybara-service bot merged commit e8fe693 into main Jan 6, 2026
24 checks passed
@copybara-service copybara-service bot deleted the hijax-guide branch January 6, 2026 21:50
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants