Flax lenen module 살펴보기
ValueError: parent must be None, Module or Scope
ValueError: parent must be None, Module or Scope
모델 구현 부분에서 다음과 같은 Error를 마주하는 경우가 종종 있다. 이 이슈의 경우 아래와 같이 모델의 입력이 아닌 입력을 넣어줬을 때 발생한다.
class Bar(nn.Module):
@nn.compact
def __call__(self, x):
pass
>>> Bar('test')
...
ValueError: parent must be None, Module or Scope
해결 방법
__setup__
이나클래스 변수
로 변수를 받아주거나- 생성자를 생성할 때 괄호를 잘못 넣어준게 아닌가 살펴보자. 예를 들어,
ResNetBlock(x)
가 아니라ResNetBlock()(x)
와 같은 형식으로 코드를 작성해야한다.
ScopeCollectionNotFound: Tried to access “mean” from collection “batch_stats” in “{YOUR_NETWORK}” but the collection is empty.
TypeError: grad requires real- or complex-valued inputs (input dtype that is a sub-dtype of np.inexact), but got int32. If you want to use Boolean- or integer-valued inputs, use vjp or set allow_int to True.
해결 방법
JAX에서 gradient를 계산하던 중 만난 에러. 에러의 설명처럼 allow_int=True
를 넣어주면 해결된다.
# 오류
(loss, (logits, updates)), grads = jax.value_and_grad(loss_fn, has_aux=True)(state, inputs, labels)
# 해결
(loss, (logits, updates)), grads = jax.value_and_grad(loss_fn, has_aux=True, allow_int=True)(state, inputs, labels)
TypeError: Gradient only defined for scalar-output functions. Output had shape: (128,).
JAX는 입력으로 주어진 함수에 대한 자동 미분(autodifferentiation)을 지원한다. grad()
함수를 사용하면 주어진 함수에 대한 그래디언트(gradient)를 계산할 수 있다.
그러나 grad()
함수는 기본적으로 스칼라 값 함수에 대해서만 작동한다. 따라서, “Gradient only defined for scalar-output functions.“와 같은 오류가 발생하는 경우에는 함수의 출력값이 스칼라 값이 아닌 벡터, 행렬 또는 텐서 등의 형태를 가지고 있기 때문이다.
해결 방법
이 경우, 해결책은 출력값이 스칼라 값이 되도록 입력 함수를 변경하는 것이다. 이를 위해서는 주어진 함수를 각각의 출력 차원에 대해 개별적으로 계산하는 함수를 만들어야 한다. 이를 일반적으로 vmap 또는 batching이라고 한다.
JAX에서는 vmap()
함수를 사용하여 벡터화된 함수(vectorized function)를 만들 수 있다. 이 함수를 사용하면 다차원 배열을 입력으로 받아 각각의 입력에 대해 함수를 실행하고 결과를 다차원 배열 형태로 반환할 수 있다. 이렇게 만들어진 벡터화된 함수는 grad()
함수에 대한 입력으로 사용할 수 있다.
예를 들어, 다음과 같은 함수가 있다고 가정해보자.
import jax.numpy as jnp
def foo(x, y):
return jnp.dot(x, y)
이 함수는 x와 y의 내적을 계산하는 함수이다. vmap()
함수를 사용하여 이 함수를 벡터화하고자 할 때, 다음과 같이 작성할 수 있다.
from jax import vmap
# vmap을 사용하여 foo 함수를 벡터화합니다.
vectorized_foo = vmap(foo, in_axes=(0, 0))
# 두 개의 배열을 준비합니다.
x = jnp.array([[1, 2], [3, 4], [5, 6]])
y = jnp.array([[2, 3], [4, 5], [6, 7]])
# 벡터화된 함수를 호출합니다.
result = vectorized_foo(x, y) # 결과: array([ 8, 26, 56], dtype=int32)
위의 예시에서 in_axes=(0, 0)
는 foo
함수의 인자 x와 y가 모두 첫 번째 차원에서 반복되도록 벡터화한다는 것을 의미한다. 이제 foo
함수의 입력 x와 y는 각각 (3, 2) 모양의 행렬이며, vmap()
을 사용하여 각 행의 내적을 계산한 결과인 (3,) 모양의 벡터가 반환된다.
ValueError: Custom node type mismatch: expected type:
TrainState 클래스에 model 추가해줌