설치

!pip install flax
!pip install --upgrade pip jax jaxlib
!pip install --upgrade git+https://github.com/google/flax.git

JAX와 Flax 버전을 맞춰줘야하는데 이게 까다롭기 때문에 colab에서 돌린다면 위 코드를 통해 모든 버전을 최신으로 올려주면 된다(23년 1월을 기준으로 아직 colab에 Flax가 기본으로 설치되어 있지 않음).

use with

사람마다 다르겠지만 나는 Flax가 Keras와 Pytorch 그 중간 어딘가라고 느껴진다(모델 구현하는 구조는 Pytorch에 좀 더 가까운 것 같다). Flax는 Numpy 대신 JAX를 사용하며, Optimizer는 optax라는 라이브러리를 사용한다. 그러므로 Flax를 사용한다=JAX+Optax+Flax를 사용한다라고 생각하면 된다.

또한 테스트 데이터를 제공해주지 않기 때문에 MNIST와 같은 기본 데이터로 테스트해보고 싶다면 keras의 모듈을 사용해야한다.

Class 구조

# 간단한 residual 블록
import flax
import jax

from flax import linen as nn
import optax

class ResidualBlock(nn.Module):
	# 1.--------------------------
    num_channels:int
    use_1x1conv:bool = True
    strides:tuple = (1,1)
    training:bool = True
	
    # 2.--------------------------
    def setup(self):
        self.conv1 = nn.Conv(self.num_channels,
                             kernel_size=(3,3),
                             strides=self.strides, 
                             use_bias=False)
        self.conv2 = nn.Conv(self.num_channels,
                             kernel_size=(3,3),
                             use_bias=False)

        if self.use_1x1conv:
            self.conv3 = nn.Conv(self.num_channels,
                                 kernel_size=(1, 1),
                                 strides=self.strides,
                                 use_bias=False)
        else:
            self.conv3 = None
        self.bn1 = nn.BatchNorm(not self.training)
        self.bn2 = nn.BatchNorm(not self.training)

	# 3.--------------------------
    def __call__(self, x):
        z = self.conv1(x)
        z = self.bn1(z)
        z = nn.relu(z)

        z = self.conv2(z)
        z = self.bn2(z)

        if self.conv3:
            x_out = self.conv3(x) + z
        else:
            x_out = z

        return nn.relu(x_out)

JAX의 모델 Class는 크게 3가지 파트로 나눌 수 있다.

1. 클래스 변수

_JAX_는 _Pytorch_와 다르게 __init__을 사용하지 않는다. 클래스 변수와 setup을 통해 그 기능을 대신한다. 클래스 변수에서는 생성자 생성시 받고 싶은 변수를 정의해주면 된다.

2. setup

setup에서는 __call__에서 사용할 변수들의 기능을 정의해주면 된다. 필수로 사용해야하는 것은 아니며, __call__에서 정의해주어도 된다. Pytorch_와 비슷한 방식으로 선언함으로써 _Pytorch 모델을 수얼하게 이식할 수 있다.

3. call

_Pytorch_의 forward 함수의 역할을 하는 함수이다. 실질적인 모델의 구조를 결정하며 return이 필수로 필요하다. 위 setup 함수에서 정의한 변수를 통해 모델을 구성해도되고, __call__에서 직접 정의해도 된다.

@nn.compact

만약 setup을 사용하지 않는다면 아래와 같이 nn.compact 데코레이터를 사용해주어야 한다. 위 코드에서는 setup 설명을 위해 사용하지 않았지만, 많은 JAX 코드에서 nn.compact를 사용하는 것을 볼 수 있다. 읽기 더 수월하고 코드 중복을 줄일 수 있기 때문에 많이들 사용하는 것 같다.

class ResidualBlock(nn.Module):
    num_channels:int
    use_1x1conv:bool = True
    strides:tuple = (1,1)
    training:bool = True

    @nn.compact
    def __call__(self, x):
        z = nn.Conv(self.num_channels,
                    kernel_size=(3,3),
                    strides=self.strides, 
                    use_bias=False)(x)
        z = nn.BatchNorm(not self.training)(z)
        z = nn.relu(z)

        z = nn.Conv(self.num_channels,
                    kernel_size=(3,3),
                    use_bias=False)(z)
        z = nn.BatchNorm(not self.training)(z)

        if self.use_1x1conv:
            x_out = nn.Conv(self.num_channels,
                            kernel_size=(1, 1),
                            strides=self.strides,
                            use_bias=False)(x)
            x_out += z
        else:
            x_out = z

        return nn.relu(x_out)

모델 확인해보기

만든 모델에 입력이 잘들어가는지 출력해보고 싶다면 init_with_output을 사용하여 출력해볼 수 있다.

from jax import random

# JAX에서는 PRNGKey를 만들어서 state를 유지해야함
key = random.PRNGKey(42)

r_bl = ResidualBlock(6, use_1x1conv=True, strides=(2,2))
x = random.normal(key, (4, 6, 6, 3))
r_bl.init_with_output(key, x)[0].shape
(4, 3, 3, 6)