이 글을 읽기 앞서 Flax의 기초로 모든 Flax Module은 dataclass(Python 3.7이상)를 사용한다. 그래서 __init__을 사용하지 않으며, 사용해야한다면 setup()함수를 이용해야 한다. 그렇지 않은 경우에는 @nn.compact__call__()함수 위에 적어주고 사용하면 된다. 이 경우에는 __call__()에서 모델의 모든 부분을 정의해주면 되고, Flax의 경우는 후자의 방식을 권장한다.

보통 nn이라고 요약하여 사용하며, PyTorch처럼 Model class에서 nn.Module을 상속받아 사용한다.

from flax import linen as nn

class Module(nn.Module):
  features: Tuple[int, ...] = (16, 4)

  def setup(self):
    self.dense1 = Dense(self.features[0])
    self.dense2 = Dense(self.features[1])

  def __call__(self, x):
    return self.dense2(nn.relu(self.dense1(x)))

apply

일반적으로 apply 함수는 다음과 같은 상황에서 사용된다:

  1. 모델의 가중치(weight)를 업데이트하기 위해, 그라디언트(gradient)를 적용할 때
  2. 배치 정규화(batch normalization) 계층에서, 평균과 분산을 업데이트할 때
  3. 데이터셋에서 데이터를 처리하거나, 데이터셋에 대한 통계를 계산할 때

예를 들어, 다음은 sgd를 사용하여 가중치를 업데이트하는 예시 코드다.

import jax
from jax import random
from flax import linen as nn

# define a simple model
class MLP(nn.Module):
  features: int
  
  @nn.compact
  def __call__(self, x):
    x = nn.Dense(features=self.features)(x)
    x = nn.relu(x)
    x = dense2 = nn.Dense(features=1)(x)
    return x

# create an instance of the model
key = random.PRNGKey(0)
model = MLP(features=32).init(key, jnp.ones((1, 10)))

# define a function to apply to the model
def set_bias_to_zero(module):
  for name, param in module.params.items():
    if name.endswith('b'):
      param = jnp.zeros_like(param)
  return module.replace(params=module.params)

# apply the function to the model using the apply method
model = model.apply(set_bias_to_zero)

모델을 초기화하고 apply 함수를 사용하여 set_bias_to_zero 함수를 모델에 적용한다. 이렇게 하면 모델의 모든 편향 가중치가 0으로 설정된다.

이 예제에서는 set_bias_to_zero 함수가 직접적으로 MLP 모듈에 적용되지만, 더 복잡한 모델의 경우, apply 함수를 사용하여 모든 하위 모듈에 적용할 수 있다.

bind

bind 함수는 함수 인자를 바인딩(binding)하는 데 사용됩니다. 바인딩함수 호출에서 일부 인자의 값을 미리 설정하는 것을 의미한다. 즉, bind 함수를 사용하면 일부 인자 값을 미리 설정한 새로운 함수를 생성할 수 있다(functools의 partial과 비슷하다고 생각하면 이해하기 쉽다).

bind 함수는 다음과 같은 구문으로 사용된다.

from flax import linen as nn

# define a Dense module with features argument
class MyDense(nn.Module):
  features: int

  @nn.compact
  def __call__(self, x):
    return nn.Dense(features=self.features)(x)

# bind the MyDense module with features=64
dense_64 = nn.bind(MyDense, features=64)

# create an instance of the new module and apply it to an input array
x = jnp.array([1.0, 2.0, 3.0])
y = dense_64().apply(x)
print(y)  # prints "[0.00240129, -0.00455584, 0.00039986]"

flax.linen.bind 함수를 사용하여 features 인자 값을 64로 설정한 새로운 모듈 dense_64을 생성한다. 그런 다음, dense_64를 호출하여 새로운 모듈의 인스턴스를 생성하고, apply 함수를 사용하여 입력 배열 x에 모듈을 적용한다.

initializers

딥러닝 모델에서는 일반적으로 가중치 초기화 함수를 사용하여 모델 가중치를 초기화하고, Flax에서 이 함수는 flax.linen.initializers 모듈에서 제공된다.

from flax import linen as nn

class MyDense(nn.Module):
  features: int
  kernel_init: Callable = nn.initializers.lecun_normal()
  bias_init: Callable = nn.initializers.zeros

  @nn.compact
  def __call__(self, x):
    y = nn.Dense(features=self.features, kernel_init=self.kernel_init, bias_init=self.bias_init)(x)
    return y

nn.initializers.lecun_normal() 함수를 사용하여 kernel_init 가중치 초기화 함수를 생성하였다. bias_initnn.initializers.zeros 함수를 사용하여 초기화 함수를 생성하였다. 그런 다음, Dense 모듈에 kernel_initbias_init 인자로 생성한 초기화 함수를 전달한다.

MyDense 클래스는 이제 kernel_initbias_init 인자로 전달한 가중치 초기화 함수를 사용하여 모델 가중치를 초기화할 수 있다.

init

flax.linen.init() 함수를 사용하여 모델 인스턴스의 가중치를 초기화할 수 있다. init() 함수는 모델 인스턴스를 인자로 받아 초기화된 모델 인스턴스를 반환한다.

예를 들어, Dense 모듈로 구성된 간단한 신경망 모델을 만들고 이를 초기화하려면 다음과 같이 할 수 있다.

import jax
import jax.numpy as jnp
from flax import linen as nn

class MLP(nn.Module):
    hidden_size: int
    output_size: int

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(self.hidden_size)(x)
        x = nn.relu(x)
        x = nn.Dense(self.output_size)(x)
        return x

model = MLP(hidden_size=64, output_size=10)

# 가중치를 초기화합니다.
rng = jax.random.PRNGKey(0)
model = nn.init(model, rng)

위 코드에서 MLP 모델 클래스를 정의하고 __call__ 메서드 안에서 Dense 모듈을 사용하여 모델을 구성한다. 이후 모델 인스턴스를 생성하고, init() 함수를 사용하여 모델의 가중치를 초기화한다. init() 함수는 무작위 시드를 사용하여 가중치를 초기화하고, 초기화된 가중치가 포함된 모델 인스턴스를 반환한다.

init_with_output

이 함수를 사용하여 모델 인스턴스의 가중치를 초기화하고 모델의 출력도 함께 계산할 수 있다. init_with_output() 함수는 flax.linen.init() 함수와 유사하게 작동하지만, 초기화된 모델을 사용하여 입력 데이터를 전달하고 모델의 출력을 계산한다.

import jax
import jax.numpy as jnp
from flax import linen as nn
from jax.experimental import stax

# MLP 모델 정의
class MLP(nn.Module):
    hidden_size: int
    output_size: int

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(self.hidden_size)(x)
        x = nn.relu(x)
        x = nn.Dense(self.output_size)(x)
        return x

model = MLP(hidden_size=64, output_size=10)

# 가중치 및 출력 초기화
rng = jax.random.PRNGKey(0)
x = jnp.ones((1, 784))
params, out = nn.init_with_output(rng, model, x)

print("Initialized parameters:", params)
print("Initialized output:", out)
Initialized parameters: FrozenDict({
    Dense_0: {
        'bias': DeviceArray([0., 0., ..., 0.], dtype=float32),
        'kernel': DeviceArray([[ 0.03351291, -0.01413143, ...,  0.05929277, -0.05504522],
                               [ 0.04876595,  0.01508847, ...,  0.02504633, -0.03403518],
                               ...,
                               [ 0.04460957, -0.02312812, ...,  0.0300643 , -0.00740753],
                               [-0.04049913, -0.0521089 , ...,  0.05811187, -0.05726928]], dtype=float32),
    },
    Dense_1: {
        'bias': DeviceArray([0., 0., ..., 0.], dtype=float32),
        'kernel': DeviceArray([[-0.07551454, -0.08685692, ...,  0.09087151, -0.00984192],
                               [-0.06961666, -0.06689157, ...,  0.00751706, -0.08733141],
                               ...,
                               [-0.0646823 , -0.04226141, ...,  0.05052247, -0.03290445],
                               [-0.03011085,  0.08997382, ...,  0.08751475, -0.0467639 ]], dtype=float32),
    },
})
Initialized output: [[-0.48999336 -0.1978706   0.01582504  0.09210121  0.11780421 -0.07384747
  -0.10397934 -0.15517813 -0.05804203 -0.12006739]]

params는 FrozenDict 형식으로 모델의 가중치를 포함하고 있다. 이전에 언급한 것처럼, params는 모델 인스턴스의 매개변수로 사용된다. out은 x를 사용하여 모델의 출력을 계산한 결과이다.