jax adventures 001

Published

June 2, 2025

I have been wanting to learn more JAX, so I am going on a JAX adventure.

import jax
import jax.numpy as jnp
import flax.linen as nn
import matplotlib.pyplot as plt
import numpy as np

These imports should cover everything I need.

I saw that Flax has some kind of self-attention module. Apparently Flax Linen is being superseded by Flax NNX, so let's try the NNX version.

from flax import nnx

# A small two-head self-attention layer. Unlike a Linen module, this NNX
# module owns its parameters right after construction (the printout below
# already shows the Param arrays), and the `rngs` stream seeded with 0
# supplies the random keys their initializers draw from.
my_selfattention = nnx.MultiHeadAttention(
    rngs=nnx.Rngs(0),
    num_heads=2,      # number of attention heads
    in_features=10,   # size of the input's last axis (query/key/value kernels are (10, 2, 3))
    qkv_features=6,   # total q/k/v width, split across heads: 6 / 2 = head_dim of 3
    out_features=12,  # size of the output's last axis (out kernel is (2, 3, 12))
    decode=False,     # not used for step-by-step decoding; cached_key/value stay None
)

Okay, now I have this object, which I assume is an NNX module. In Linen, a module usually has to be initialized with parameters in a separate step, so why does this one take an `rngs` argument?

# Dump the module: its hyperparameters plus every Param it already holds.
print(f"{my_selfattention}")
MultiHeadAttention( # Param: 282 (1.1 KB)

  num_heads=2,

  in_features=10,

  qkv_features=6,

  out_features=12,

  dtype=None,

  param_dtype=float32,

  broadcast_dropout=True,

  dropout_rate=0.0,

  deterministic=None,

  precision=None,

  kernel_init=<function variance_scaling.<locals>.init at 0x73e5c28e2700>,

  out_kernel_init=None,

  bias_init=<function zeros at 0x73e5c362c5e0>,

  out_bias_init=None,

  use_bias=True,

  attention_fn=<function dot_product_attention at 0x73e5c28e2200>,

  decode=False,

  normalize_qk=False,

  qkv_dot_general=None,

  out_dot_general=None,

  qkv_dot_general_cls=None,

  out_dot_general_cls=None,

  head_dim=3,

  query=LinearGeneral( # Param: 66 (264 B)

    in_features=(10,),

    out_features=(2, 3),

    axis=(-1,),

    batch_axis=FrozenDict({}),

    use_bias=True,

    dtype=None,

    param_dtype=float32,

    kernel_init=<function variance_scaling.<locals>.init at 0x73e5c28e2700>,

    bias_init=<function zeros at 0x73e5c362c5e0>,

    precision=None,

    dot_general=None,

    dot_general_cls=None,

    promote_dtype=<function promote_dtype at 0x73e5c28e2160>,

    kernel=Param( # 60 (240 B)

      value=Array(shape=(10, 2, 3), dtype=dtype('float32'))

    ),

    bias=Param( # 6 (24 B)

      value=Array(shape=(2, 3), dtype=dtype('float32'))

    )

  ),

  key=LinearGeneral( # Param: 66 (264 B)

    in_features=(10,),

    out_features=(2, 3),

    axis=(-1,),

    batch_axis=FrozenDict({}),

    use_bias=True,

    dtype=None,

    param_dtype=float32,

    kernel_init=<function variance_scaling.<locals>.init at 0x73e5c28e2700>,

    bias_init=<function zeros at 0x73e5c362c5e0>,

    precision=None,

    dot_general=None,

    dot_general_cls=None,

    promote_dtype=<function promote_dtype at 0x73e5c28e2160>,

    kernel=Param( # 60 (240 B)

      value=Array(shape=(10, 2, 3), dtype=dtype('float32'))

    ),

    bias=Param( # 6 (24 B)

      value=Array(shape=(2, 3), dtype=dtype('float32'))

    )

  ),

  value=LinearGeneral( # Param: 66 (264 B)

    in_features=(10,),

    out_features=(2, 3),

    axis=(-1,),

    batch_axis=FrozenDict({}),

    use_bias=True,

    dtype=None,

    param_dtype=float32,

    kernel_init=<function variance_scaling.<locals>.init at 0x73e5c28e2700>,

    bias_init=<function zeros at 0x73e5c362c5e0>,

    precision=None,

    dot_general=None,

    dot_general_cls=None,

    promote_dtype=<function promote_dtype at 0x73e5c28e2160>,

    kernel=Param( # 60 (240 B)

      value=Array(shape=(10, 2, 3), dtype=dtype('float32'))

    ),

    bias=Param( # 6 (24 B)

      value=Array(shape=(2, 3), dtype=dtype('float32'))

    )

  ),

  query_ln=None,

  key_ln=None,

  out=LinearGeneral( # Param: 84 (336 B)

    in_features=(2, 3),

    out_features=(12,),

    axis=(-2, -1),

    batch_axis=FrozenDict({}),

    use_bias=True,

    dtype=None,

    param_dtype=float32,

    kernel_init=<function variance_scaling.<locals>.init at 0x73e5c28e2700>,

    bias_init=<function zeros at 0x73e5c362c5e0>,

    precision=None,

    dot_general=None,

    dot_general_cls=None,

    promote_dtype=<function promote_dtype at 0x73e5c28e2160>,

    kernel=Param( # 72 (288 B)

      value=Array(shape=(2, 3, 12), dtype=dtype('float32'))

    ),

    bias=Param( # 12 (48 B)

      value=Array(shape=(12,), dtype=dtype('float32'))

    )

  ),

  rngs=None,

  cached_key=None,

  cached_value=None,

  cache_index=None

)