Flax lenen 오류 모음
이 글을 읽기 앞서 Flax의 기초로 모든 Flax Module은 dataclass(Python 3.7이상)를 사용한다. 그래서 __init__
을 사용하지 않으며, 사용해야한다면 setup()
함수를 이용해야 한다. 그렇지 않은 경우에는 @nn.compact
를 __call__()
함수 위에 적어주고 사용하면 된다. 이 경우에는 __call__()
에서 모델의 모든 부분을 정의해주면 되고, Flax의 경우는 후자의 방식을 권장한다.
보통 nn
이라고 요약하여 사용하며, PyTorch처럼 Model class에서 nn.Module
을 상속받아 사용한다.
from flax import linen as nn
class Module(nn.Module):
features: Tuple[int, ...] = (16, 4)
def setup(self):
self.dense1 = Dense(self.features[0])
self.dense2 = Dense(self.features[1])
def __call__(self, x):
return self.dense2(nn.relu(self.dense1(x)))
apply
일반적으로 apply 함수는 다음과 같은 상황에서 사용된다:
- 모델의 가중치(weight)를 업데이트하기 위해, 그라디언트(gradient)를 적용할 때
- 배치 정규화(batch normalization) 계층에서, 평균과 분산을 업데이트할 때
- 데이터셋에서 데이터를 처리하거나, 데이터셋에 대한 통계를 계산할 때
예를 들어, 다음은 sgd를 사용하여 가중치를 업데이트하는 예시 코드다.
import jax
from jax import random
from flax import linen as nn
# define a simple model
class MLP(nn.Module):
features: int
@nn.compact
def __call__(self, x):
x = nn.Dense(features=self.features)(x)
x = nn.relu(x)
x = dense2 = nn.Dense(features=1)(x)
return x
# create an instance of the model
key = random.PRNGKey(0)
model = MLP(features=32).init(key, jnp.ones((1, 10)))
# define a function to apply to the model
def set_bias_to_zero(module):
for name, param in module.params.items():
if name.endswith('b'):
param = jnp.zeros_like(param)
return module.replace(params=module.params)
# apply the function to the model using the apply method
model = model.apply(set_bias_to_zero)
모델을 초기화하고 apply
함수를 사용하여 set_bias_to_zero
함수를 모델에 적용한다. 이렇게 하면 모델의 모든 편향 가중치가 0으로 설정된다.
이 예제에서는 set_bias_to_zero
함수가 직접적으로 MLP 모듈에 적용되지만, 더 복잡한 모델의 경우, apply
함수를 사용하여 모든 하위 모듈에 적용할 수 있다.
bind
bind
함수는 함수 인자를 바인딩(binding)하는 데 사용됩니다. 바인딩은 함수 호출에서 일부 인자의 값을 미리 설정하는 것을 의미한다. 즉, bind
함수를 사용하면 일부 인자 값을 미리 설정한 새로운 함수를 생성할 수 있다(functools의 partial
과 비슷하다고 생각하면 이해하기 쉽다).
bind 함수는 다음과 같은 구문으로 사용된다.
from flax import linen as nn
# define a Dense module with features argument
class MyDense(nn.Module):
features: int
@nn.compact
def __call__(self, x):
return nn.Dense(features=self.features)(x)
# bind the MyDense module with features=64
dense_64 = nn.bind(MyDense, features=64)
# create an instance of the new module and apply it to an input array
x = jnp.array([1.0, 2.0, 3.0])
y = dense_64().apply(x)
print(y) # prints "[0.00240129, -0.00455584, 0.00039986]"
flax.linen.bind
함수를 사용하여 features 인자 값을 64로 설정한 새로운 모듈 dense_64을 생성한다. 그런 다음, dense_64를 호출하여 새로운 모듈의 인스턴스를 생성하고, apply
함수를 사용하여 입력 배열 x에 모듈을 적용한다.
initializers
딥러닝 모델에서는 일반적으로 가중치 초기화 함수를 사용하여 모델 가중치를 초기화하고, Flax에서 이 함수는 flax.linen.initializers
모듈에서 제공된다.
from flax import linen as nn
class MyDense(nn.Module):
features: int
kernel_init: Callable = nn.initializers.lecun_normal()
bias_init: Callable = nn.initializers.zeros
@nn.compact
def __call__(self, x):
y = nn.Dense(features=self.features, kernel_init=self.kernel_init, bias_init=self.bias_init)(x)
return y
nn.initializers.lecun_normal()
함수를 사용하여 kernel_init
가중치 초기화 함수를 생성하였다. bias_init
는 nn.initializers.zeros
함수를 사용하여 초기화 함수를 생성하였다. 그런 다음, Dense 모듈에 kernel_init
과 bias_init
인자로 생성한 초기화 함수를 전달한다.
MyDense 클래스는 이제 kernel_init
과 bias_init
인자로 전달한 가중치 초기화 함수를 사용하여 모델 가중치를 초기화할 수 있다.
init
flax.linen.init() 함수를 사용하여 모델 인스턴스의 가중치를 초기화할 수 있다. init() 함수는 모델 인스턴스를 인자로 받아 초기화된 모델 인스턴스를 반환한다.
예를 들어, Dense 모듈로 구성된 간단한 신경망 모델을 만들고 이를 초기화하려면 다음과 같이 할 수 있다.
import jax
import jax.numpy as jnp
from flax import linen as nn
class MLP(nn.Module):
hidden_size: int
output_size: int
@nn.compact
def __call__(self, x):
x = nn.Dense(self.hidden_size)(x)
x = nn.relu(x)
x = nn.Dense(self.output_size)(x)
return x
model = MLP(hidden_size=64, output_size=10)
# 가중치를 초기화합니다.
rng = jax.random.PRNGKey(0)
model = nn.init(model, rng)
위 코드에서 MLP 모델 클래스를 정의하고 __call__
메서드 안에서 Dense 모듈을 사용하여 모델을 구성한다. 이후 모델 인스턴스를 생성하고, init()
함수를 사용하여 모델의 가중치를 초기화한다. init()
함수는 무작위 시드를 사용하여 가중치를 초기화하고, 초기화된 가중치가 포함된 모델 인스턴스를 반환한다.
init_with_output
이 함수를 사용하여 모델 인스턴스의 가중치를 초기화하고 모델의 출력도 함께 계산할 수 있다. init_with_output()
함수는 flax.linen.init()
함수와 유사하게 작동하지만, 초기화된 모델을 사용하여 입력 데이터를 전달하고 모델의 출력을 계산한다.
import jax
import jax.numpy as jnp
from flax import linen as nn
from jax.experimental import stax
# MLP 모델 정의
class MLP(nn.Module):
hidden_size: int
output_size: int
@nn.compact
def __call__(self, x):
x = nn.Dense(self.hidden_size)(x)
x = nn.relu(x)
x = nn.Dense(self.output_size)(x)
return x
model = MLP(hidden_size=64, output_size=10)
# 가중치 및 출력 초기화
rng = jax.random.PRNGKey(0)
x = jnp.ones((1, 784))
params, out = nn.init_with_output(rng, model, x)
print("Initialized parameters:", params)
print("Initialized output:", out)
Initialized parameters: FrozenDict({
Dense_0: {
'bias': DeviceArray([0., 0., ..., 0.], dtype=float32),
'kernel': DeviceArray([[ 0.03351291, -0.01413143, ..., 0.05929277, -0.05504522],
[ 0.04876595, 0.01508847, ..., 0.02504633, -0.03403518],
...,
[ 0.04460957, -0.02312812, ..., 0.0300643 , -0.00740753],
[-0.04049913, -0.0521089 , ..., 0.05811187, -0.05726928]], dtype=float32),
},
Dense_1: {
'bias': DeviceArray([0., 0., ..., 0.], dtype=float32),
'kernel': DeviceArray([[-0.07551454, -0.08685692, ..., 0.09087151, -0.00984192],
[-0.06961666, -0.06689157, ..., 0.00751706, -0.08733141],
...,
[-0.0646823 , -0.04226141, ..., 0.05052247, -0.03290445],
[-0.03011085, 0.08997382, ..., 0.08751475, -0.0467639 ]], dtype=float32),
},
})
Initialized output: [[-0.48999336 -0.1978706 0.01582504 0.09210121 0.11780421 -0.07384747
-0.10397934 -0.15517813 -0.05804203 -0.12006739]]
params
는 FrozenDict 형식으로 모델의 가중치를 포함하고 있다. 이전에 언급한 것처럼, params
는 모델 인스턴스의 매개변수로 사용된다. out
은 x를 사용하여 모델의 출력을 계산한 결과이다.