import jax
import jax.numpy as jnp
import flax.linen as nn
import matplotlib.pyplot as plt
import numpy as np
I have been wanting to learn more Jax, so I will go on a Jax adventure.
These imports should be sufficient for what I’d need.
I saw there is some kind of self-attention module in Flax. Apparently Flax Linen is being superseded by Flax NNX.
from flax import nnx
my_selfattention = nnx.MultiHeadAttention(
    num_heads=2,
    in_features=10,
    qkv_features=6,
    out_features=12,
    decode=False,
    rngs=nnx.Rngs(0),
)
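Before I poke at this object, a quick comparison: with Linen I would normally construct the module and then initialize its parameters in a separate step with an explicit PRNG key. Roughly like this (a sketch from memory, I haven't double-checked the exact call signature; x is just a dummy input I made up):

# Linen-style sketch for comparison: the module object itself holds no parameters,
# they only come into existence when .init() is called with a PRNG key.
x = jnp.ones((1, 4, 10))  # (batch, seq_len, features) dummy input
linen_attn = nn.MultiHeadDotProductAttention(
    num_heads=2,
    qkv_features=6,
    out_features=12,
)
params = linen_attn.init(jax.random.PRNGKey(0), x)  # separate parameter initialization
y = linen_attn.apply(params, x)                     # forward pass needs the params passed back in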
Okay, now I have this object; I assume it is an nnx module. Usually a module then has to be initialized with parameters separately, so why is there already an rngs associated with it?
print(my_selfattention)
MultiHeadAttention( # Param: 282 (1.1 KB)
  num_heads=2,
  in_features=10,
  qkv_features=6,
  out_features=12,
  dtype=None,
  param_dtype=float32,
  broadcast_dropout=True,
  dropout_rate=0.0,
  deterministic=None,
  precision=None,
  kernel_init=<function variance_scaling.<locals>.init at 0x73e5c28e2700>,
  out_kernel_init=None,
  bias_init=<function zeros at 0x73e5c362c5e0>,
  out_bias_init=None,
  use_bias=True,
  attention_fn=<function dot_product_attention at 0x73e5c28e2200>,
  decode=False,
  normalize_qk=False,
  qkv_dot_general=None,
  out_dot_general=None,
  qkv_dot_general_cls=None,
  out_dot_general_cls=None,
  head_dim=3,
  query=LinearGeneral( # Param: 66 (264 B)
    in_features=(10,),
    out_features=(2, 3),
    axis=(-1,),
    batch_axis=FrozenDict({}),
    use_bias=True,
    dtype=None,
    param_dtype=float32,
    kernel_init=<function variance_scaling.<locals>.init at 0x73e5c28e2700>,
    bias_init=<function zeros at 0x73e5c362c5e0>,
    precision=None,
    dot_general=None,
    dot_general_cls=None,
    promote_dtype=<function promote_dtype at 0x73e5c28e2160>,
    kernel=Param( # 60 (240 B) value=Array(shape=(10, 2, 3), dtype=dtype('float32')) ),
    bias=Param( # 6 (24 B) value=Array(shape=(2, 3), dtype=dtype('float32')) )
  ),
  key=LinearGeneral( # Param: 66 (264 B)
    in_features=(10,),
    out_features=(2, 3),
    axis=(-1,),
    batch_axis=FrozenDict({}),
    use_bias=True,
    dtype=None,
    param_dtype=float32,
    kernel_init=<function variance_scaling.<locals>.init at 0x73e5c28e2700>,
    bias_init=<function zeros at 0x73e5c362c5e0>,
    precision=None,
    dot_general=None,
    dot_general_cls=None,
    promote_dtype=<function promote_dtype at 0x73e5c28e2160>,
    kernel=Param( # 60 (240 B) value=Array(shape=(10, 2, 3), dtype=dtype('float32')) ),
    bias=Param( # 6 (24 B) value=Array(shape=(2, 3), dtype=dtype('float32')) )
  ),
  value=LinearGeneral( # Param: 66 (264 B)
    in_features=(10,),
    out_features=(2, 3),
    axis=(-1,),
    batch_axis=FrozenDict({}),
    use_bias=True,
    dtype=None,
    param_dtype=float32,
    kernel_init=<function variance_scaling.<locals>.init at 0x73e5c28e2700>,
    bias_init=<function zeros at 0x73e5c362c5e0>,
    precision=None,
    dot_general=None,
    dot_general_cls=None,
    promote_dtype=<function promote_dtype at 0x73e5c28e2160>,
    kernel=Param( # 60 (240 B) value=Array(shape=(10, 2, 3), dtype=dtype('float32')) ),
    bias=Param( # 6 (24 B) value=Array(shape=(2, 3), dtype=dtype('float32')) )
  ),
  query_ln=None,
  key_ln=None,
  out=LinearGeneral( # Param: 84 (336 B)
    in_features=(2, 3),
    out_features=(12,),
    axis=(-2, -1),
    batch_axis=FrozenDict({}),
    use_bias=True,
    dtype=None,
    param_dtype=float32,
    kernel_init=<function variance_scaling.<locals>.init at 0x73e5c28e2700>,
    bias_init=<function zeros at 0x73e5c362c5e0>,
    precision=None,
    dot_general=None,
    dot_general_cls=None,
    promote_dtype=<function promote_dtype at 0x73e5c28e2160>,
    kernel=Param( # 72 (288 B) value=Array(shape=(2, 3, 12), dtype=dtype('float32')) ),
    bias=Param( # 12 (48 B) value=Array(shape=(12,), dtype=dtype('float32')) )
  ),
  rngs=None,
  cached_key=None,
  cached_value=None,
  cache_index=None
)
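That more or less answers my question: every Param in the printout already holds a concrete Array, so it looks like nnx modules create their parameters eagerly inside the constructor, and that is what the rngs argument is consumed for. Which would mean I can use the module directly, with no separate init/apply step. A quick check (reusing the dummy x from above; the expected shapes are my reading of the repr, not something I have verified against the docs):

# The parameters already live on the object itself:
print(my_selfattention.query.kernel.value.shape)  # (10, 2, 3): (in_features, num_heads, head_dim)

# And the module can be called directly:
out = my_selfattention(x)  # self-attention: queries, keys and values all come from x
print(out.shape)           # (1, 4, 12): last axis is out_features

# The parameter count in the repr also checks out:
#   query/key/value: kernel (10, 2, 3) = 60 plus bias (2, 3) = 6  -> 66 each
#   out projection:  kernel (2, 3, 12) = 72 plus bias (12,)  = 12 -> 84
#   total: 3 * 66 + 84 = 282, matching "# Param: 282" above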