Home Improving regression performance with distributional losses(ICML 2018)
Post
Cancel

Improving regression performance with distributional losses(ICML 2018)


Knowledge Distillation 분야가 요즘 핫한거 같다. 직독직해하자면 지식을 전달해주는 것이다. 똑똑한 모델의 지식을 멍청한 모델에게 전달해주는 분야이다. 오늘 리뷰하려는 논문은 아니니 컨셉만 말해보자면, 어떤 Task에 아주 학습이 잘 된 Layer 100개짜리 무거운 모델이 있다고 가정해보자. 우리는 이것을 실생활에서도 사용할 수 있도록 가벼운 버젼으로 개발하고 싶다. 그럼 이 상황에서 가벼운 모델을 학습할 때는 그냥 맨땅에 헤딩하듯 Task를 가르치는 것이 아니라, 미리 학습해 놓은 무거운 모델의 Output인 확률 벡터를 흉내내도록 학습 시키는 것이다. 이것을 일반적으로 Teacher - Student 관계로 정의하고 아래 사진과 같은 방식으로 학습시켜준다.



그런데 사실 이건 모두 서론이고 Knowledge Distillation이 왜 가능한지만 이해하면 된다. 선생님 모델이 처음 MNIST 데이터를 학습한 시점의 정답 라벨은 [1 0 0 0 0 0 0 0] 과 같이 극단적인 모습이다. 그러나 선생님 모델이 내놓은 Output은 1과 7에는 유사한 확률을, 1과 8에는 다른 확률을 내놓는 식의 똑똑한 지식이 담겨져 있다.

즉, 확률 벡터와 같이 지식이 담겨져 있는 Soft Label을 사용하는 것이 여러 분류 문제의 성능 향상에 기여한다는 것이다. Soft Label을 사용하는 방법에는 단순히 노이즈를 추가하거나 위에 설명한 것과 같이 Knowledge Distillation을 사용하는 등의 방법이 있다.

이미 강화학습 분야에서는 모델에게 분포를 학습하는 것이 왜 성능을 향상시키는가에 대해 수학적으로 증명하였다고 하는데 큰 관심은 없다. 하여튼, 해당 논문에서는 이것을 일반화하여 범용적인 Regression 문제에서 적용하려 한다. Soft Label을 Regression Task에 적용한 새로운 형태의 Loss를 제안한다. 단순히 제안과 실험에서 끝내는 것이 아니라 성능 향상의 이유를 수학적으로 추론한 부분이 특히 흥미롭다.


Paper Summary

  1. Distributional Losses를 통해 Regression 문제의 예측 성능을 향상시킬 수 있다.

  2. 우리는 해당 로스가 왜 더 좋은 성능을 발휘하는가에 대해 수학적으로 추론했다.

  3. 논문에서 제안된 Distributional Losses = Histogram Loss는 계산이 용이하다.


Histogram Loss

논문에서 제안한 Histogram Loss(HS)는 Input X에 대한 Ouput Y의 분포를 예측하도록 학습된다. 그리고 True라고 믿는 Y 분포와 Predict Y’ 분포간의 KL-Divergence의 크기가 Loss가 된다. 처음에는 Output의 형태가 직관적으로 와닿지 않을 수 있지만 로스의 이름이 Histogram이라는 것을 기억하자. 우리가 학습할 모델은 히스토그램을 아웃풋으로 내놓게 된다. 쉽게 말하자면, 히스토그램의 막대 갯수만큼의 숫자를 내놓게 된다. 정확히 말하자면, 합이 1이며 [0,1] 사이의 k(# of bins) 길이의 벡터를 내놓게 된다.



불행하게도 논문에서는 직관적으로 이것을 이해시켜줄 어떤 그림도 없다. 저자의 머릿속에 있었을 그림이 위의 느낌일 것이라 추측해본다. Histogram Loss가 목적하는 것은 연속적인 형태의 True 분포와 아웃풋으로 내놓은 히스토그램이 비슷한 것이다. 각 원소가 density를 의미하는 k 길이의 벡터가 분포의 안에 쏙 들어오는 것이다. 이것을 KL-Divergece를 사용하여 표현하면 아래와 같다.


  • 𝐷_𝐾𝐿 (𝑃//𝑄)= ∑_(𝑥∈𝑋)〖𝑝(𝑥)∗log⁡((𝑄(𝑥))/(𝑃(𝑥)))〗

    = ∑〖𝑝(𝑥)∗log⁡(𝑞(𝑥))〗− ∑〖𝑝(𝑥)∗log⁡(𝑝(𝑥))〗

    = 𝐻(𝑃,𝑄) − 𝐻(𝑃)


  • 우리는 Q에 대한 Gradient를 사용할 것이니 Seconde Term은 버린다

    = −∫〖𝑝(𝑦)∗log⁡(𝑞_𝑥 (𝑦))𝑑𝑦〗 _

    _= −∑ ∫〖𝑝(𝑦)∗log⁡〖(𝑓_𝑖 (𝑥))/𝑤_𝑖 〗𝑑𝑦〗


  • 우리의 아웃풋 분포 Q는 히스토그램이므로 integral이 아닌 시그마 형태로 바꿀 수 있다

    = −∑log⁡〖(𝑓_𝑖 (𝑥))/𝑤_𝑖 〗*(𝐹(𝑙_𝑖 + 𝑤_𝑖)− 𝐹(𝑙_𝑖))


  • 우리의 타겟 분포 P는 연속이기 때문에 integral 속의 p를 cdf를 사용하여 표현한다

    ≈ −∑𝑝_𝑖 *log(𝑓_𝑖 (𝑥))

히스토그램 bin의 width는 상수이므로 버리고나면 HS는 k개의 이산적인 entropy 꼴이 된다


이렇게 우리의 로스를 이해하고 나면, True Distribution P를 무엇으로 믿을것인가에 대한 문제만 남는다. 베이지안 추론에서 Prior Probability를 적절히 설정하는 것이 중요하듯 좋은 P를 설정하는 것은 중요하다. 논문에서는 Truncated Normal Distribution / Dirac delta Distribution을 기본적으로 제안하고 Teacher Model이 존재하는 Classfication Task라면 soft label을 True P로 가정한다.

정규 분포를 사용하지 않고 Truncated 분포를 사용하는 이유는 히스토그램인 Q와 구간의 범위를 맞추기 위함이다. 이렇듯, 다양한 P를 사용할 수 있어 상황에 따른 풍부한 해석이 가능하다고 저자는 말하고 있는데, 아쉽게도 논문의 모든 실험에서 Truncated Normal Distribution이 가장 좋은 성적을 보여 저자가 말한 상황이 무엇인지는 알 수 없었다.




위와 같이 Truncated 정규 분포의 확률 값 p(i)를 구할 수 있다. k개 구간의 확률을 한 번만 미리 구해놓으면, 우리의 f가 내놓는 k길이의 density ouput과의 연산을 반복해서 편하게 수행할 수 있다. 이것이 바로 저자가 Easy to Compute라고 주장한 이유란다. 물론 단순히 mse를 계산하는 것보다야 훨씬 코스트가 크겠지만 기존의 다른 방법론에서 사용되던 연속 분포 두 개의 divergence를 계산하는 것보다야 훨씬 빠르다고 한다.

또한 해당 로스의 가장 큰 장점은 빠르게 수렴하는 특성에 있는데 stochastic한 학습에서 빠르게 수렴하는 것이 일반화에 큰 도움이 된다고 한다. Loss의 Upper Bound가 작을수록 빠르게 수렴한다는 것이 15년도 논문에서 증명했다고 하니 다음에 읽어봐야겠다. 하여튼 해당 로스가 빠르게 수렴 할 것이라는 추론은 아래의 수학적 논리를 통해 증명할 수 있다.



Appendix까지 추가할 것이었으면 조금 더 친절하게 증명해줘도 되지 않았을까라는 생각이 들지만, 분수꼴의 미분과 편미분을 할 수 있다면 천천히 읽다보면 금방 이해할 수 있다. 결론만 말하자면 HS의 Gradient는 ∑p(i) - f(x_i)의 배수를 Upper Bound로 갖고, 같은 방식으로 MSE loss의 Upper bound를 계산하면 f(x) - y라는 것이다. HS의 것은 커봤자 1을 넘을 수 없지만 MSE의 로스는 매우매우 커질 수 있고, 이것은 빠르게 수렴하는데 방해가 될 것이라는 것이 저자의 주요 논지이다.



위의 실험이 그것을 증명하는데, 맨 왼쪽의 그래프를 통해 Truncated Normal Distribtion을 사용한 HS가 여러 비교군에 비해 빠르고 안정되게 학습되는 것을 볼 수 있다. 또한 옆의 오른쪽 두 그래프를 통해서는 빠르고 안정적일 뿐 아니라 더 정확하다는 것도 보여주고 있다.

논문을 읽으며 매우 흥미로웠기 때문에 많은 부분에 활용되지 않았을까 생각했지만 논문의 신선함에 비해서는 Cited된 수가 너무 적어 의외였다. 최근에 이와 비슷한 Distributional Loss를 활용하여 시계열 모델을 적합하는 논문이 나왔다고 하여 다음에 리뷰해보겠다.

(Paper URL)Improving regression performance with distributional losses

This post is licensed under CC BY 4.0 by the author.

-

Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks (ICML 2017)