JLOG

Loss Function 정리(MSE, Cross Entropy error) 본문

AI/basic concepts of AI

Loss Function 정리(MSE, Cross Entropy error)

정정선선 2021. 3. 9. 11:59

손실 함수란?

학습을 통해 최적 가중치 매개변수(W)를 결정하기 위한 지표로 손실함수(loss function)을 사용한다.

오차(loss, cost) 값을 작게 만드는 것이 신경망 학습의 목표이다.

 

Gradient Descent

loss function의 gradient(기울기)를 이용하여 loss가 최소화 되는 방향으로 학습시키는 것

위와 같이 loss function의 기울기가 -라면 loss 값이 최소가 되는 방향인 +방향으로 이동하고,

loss function의 기울기가 +라면 loss 값이 최소가 되는 방향인 -방향으로 이동할 것이다.

 

 

 

평균 제곱 오차(mean square error, MSE)

yk : 신경망 출력

tk : 정답 레이블(one-hot encoding 형식)

n : 데이터의 개수

 

식에는 안나와있지만, one-hot encoding을 사용하는 경우 최대 Error를 1로 정규화하기 위하여 2를 나눠주는 과정을 가진다.

 

설명 사진)

코드

def mse(y, t) : 
    return np.average(np.sum(np.square(y - t))) / 2

주의할 것은 y와 t의 input shape가 (num_image, num_class)가 되어야지 정상적으로 작동한다.

 

WORST CASE에 관한 고찰

MSE의 최대 값은 1이라고 한다. one-hot encoding으로 mse를 구하면 최대 값이 2가 나옴으로 정규화를 위해 2를 나눠준다고 한다.

 

그래서 위와 같은 식이 나온 것인데, 확인하기 위해 sklearn mse와 비교해보았다.

 

sklearn의 MSE의 경우 image data 개수로 나눠주는 것이 아니라 class의 개수로 나누어주어 답이 다르게 나왔다. one-hot encoding으로 생각하지 않고, 각 클래스들을 하나의 독립된 sample로 생각해서 답이 다른 것 같다.

# sklearn MSE input parameter 설명
'''
y_true : array-like of shape (n_samples,) or (n_samples, n_outputs)
        Ground truth (correct) target values.
y_pred : array-like of shape (n_samples,) or (n_samples, n_outputs)
        Estimated target values.
'''

# sklearn MSE 코드 발췌 부분
output_errors = np.average((y_true - y_pred) ** 2, axis=0,
                               weights=sample_weight)

그래서 input shape가 (num_image, num_class)에 해당한다면 위에서 작성한 코드 식이 맞다.

one-hot encoding 형식을 이용하고 max 값을 1로 설정하기 위해서 정규화가 필요한 것이 맞다고 생각한다.

 

 

 

교차 엔트로피 오차(Cross Entropy Error, CEE)

yk : 신경망 출력

tk : 정답 레이블(one-hot encoding 형식)

n : 데이터의 개수

 

tk은 정답 레이블만 1이고 나머지 레이블은 0인 one hot encoding 형식이므로,

정답 레이블 index에 -log를 취한 것이 된다.

 

y=-log(x) graph를 보면, yk가 0에 가까울 수록 loss 값이 급격하게 증가함을 확인할 수 있다.

 

def cee(y, t) : 
	delta = 1e-7
	return -np.sum(t*np.log(y+delta))

worst case인 경우 loss가 inf(무한대) 값이 되므로, 작은 delta(1e-7)을 더해서 Inf가 발생하지 않도록 한다.

 

 

 

reference

  • sklearn mse

https://scikit-learn.org/stable/modules/generated/sklearn.metrics.mean_squared_error.html

 

 

 

+)참고하면 좋을 자료 - 5개 Loss fucntion에 대해 매우 자세하게 설명한 자료

https://heartbeat.fritz.ai/5-regression-loss-functions-all-machine-learners-should-know-4fb140e9d4b0

 

Comments