Flax 사용법 요약
설치
!pip install flax
!pip install --upgrade pip jax jaxlib
!pip install --upgrade git+https://github.com/google/flax.git
JAX와 Flax 버전을 맞춰줘야하는데 이게 까다롭기 때문에 colab에서 돌린다면 위 코드를 통해 모든 버전을 최신으로 올려주면 된다(23년 1월을 기준으로 아직 colab에 Flax가 기본으로 설치되어 있지 않음).
use with
사람마다 다르겠지만 나는 Flax가 Keras와 Pytorch 그 중간 어딘가라고 느껴진다(모델 구현하는 구조는 Pytorch에 좀 더 가까운 것 같다). Flax는 Numpy 대신 JAX
를 사용하며, Optimizer는 optax
라는 라이브러리를 사용한다. 그러므로 Flax를 사용한다=JAX+Optax+Flax를 사용한다라고 생각하면 된다.
또한 테스트 데이터를 제공해주지 않기 때문에 MNIST와 같은 기본 데이터로 테스트해보고 싶다면 keras의 모듈을 사용해야한다.
Class 구조
# 간단한 residual 블록
import flax
import jax
from flax import linen as nn
import optax
class ResidualBlock(nn.Module):
# 1.--------------------------
num_channels:int
use_1x1conv:bool = True
strides:tuple = (1,1)
training:bool = True
# 2.--------------------------
def setup(self):
self.conv1 = nn.Conv(self.num_channels,
kernel_size=(3,3),
strides=self.strides,
use_bias=False)
self.conv2 = nn.Conv(self.num_channels,
kernel_size=(3,3),
use_bias=False)
if self.use_1x1conv:
self.conv3 = nn.Conv(self.num_channels,
kernel_size=(1, 1),
strides=self.strides,
use_bias=False)
else:
self.conv3 = None
self.bn1 = nn.BatchNorm(not self.training)
self.bn2 = nn.BatchNorm(not self.training)
# 3.--------------------------
def __call__(self, x):
z = self.conv1(x)
z = self.bn1(z)
z = nn.relu(z)
z = self.conv2(z)
z = self.bn2(z)
if self.conv3:
x_out = self.conv3(x) + z
else:
x_out = z
return nn.relu(x_out)
JAX의 모델 Class는 크게 3가지 파트로 나눌 수 있다.
1. 클래스 변수
_JAX_는 _Pytorch_와 다르게 __init__
을 사용하지 않는다. 클래스 변수와 setup
을 통해 그 기능을 대신한다. 클래스 변수에서는 생성자 생성시 받고 싶은 변수를 정의해주면 된다.
2. setup
setup에서는 __call__
에서 사용할 변수들의 기능을 정의해주면 된다. 필수로 사용해야하는 것은 아니며, __call__
에서 정의해주어도 된다. Pytorch_와 비슷한 방식으로 선언함으로써 _Pytorch 모델을 수얼하게 이식할 수 있다.
3. call
_Pytorch_의 forward 함수의 역할을 하는 함수이다. 실질적인 모델의 구조를 결정하며 return이 필수로 필요하다. 위 setup
함수에서 정의한 변수를 통해 모델을 구성해도되고, __call__
에서 직접 정의해도 된다.
@nn.compact
만약 setup을 사용하지 않는다면 아래와 같이 nn.compact
데코레이터를 사용해주어야 한다. 위 코드에서는 setup 설명을 위해 사용하지 않았지만, 많은 JAX 코드에서 nn.compact
를 사용하는 것을 볼 수 있다. 읽기 더 수월하고 코드 중복을 줄일 수 있기 때문에 많이들 사용하는 것 같다.
class ResidualBlock(nn.Module):
num_channels:int
use_1x1conv:bool = True
strides:tuple = (1,1)
training:bool = True
@nn.compact
def __call__(self, x):
z = nn.Conv(self.num_channels,
kernel_size=(3,3),
strides=self.strides,
use_bias=False)(x)
z = nn.BatchNorm(not self.training)(z)
z = nn.relu(z)
z = nn.Conv(self.num_channels,
kernel_size=(3,3),
use_bias=False)(z)
z = nn.BatchNorm(not self.training)(z)
if self.use_1x1conv:
x_out = nn.Conv(self.num_channels,
kernel_size=(1, 1),
strides=self.strides,
use_bias=False)(x)
x_out += z
else:
x_out = z
return nn.relu(x_out)
모델 확인해보기
만든 모델에 입력이 잘들어가는지 출력해보고 싶다면 init_with_output
을 사용하여 출력해볼 수 있다.
from jax import random
# JAX에서는 PRNGKey를 만들어서 state를 유지해야함
key = random.PRNGKey(42)
r_bl = ResidualBlock(6, use_1x1conv=True, strides=(2,2))
x = random.normal(key, (4, 6, 6, 3))
r_bl.init_with_output(key, x)[0].shape
(4, 3, 3, 6)