DeepLearning

EMA (Exponential Moving Average) 알고리즘

jiheek 2022. 7. 8. 18:16
model = create_model('efficientnet_b0')
model_m = create_model('efficientnet_b0')
momentum = 0.999
alpha = 0.5
copy_params() #copy parameters: model_m <- model 
       
    def forward(self, image, labels = None):
        output = model(image)
        m_loss = output #???
        if labels is not None: #update param of model_m
            with torch.no_grad():
                _momentum_update() #EMA
                output_m = model_m(image)
                soft_labels = F.softmax(output_m,dim=-1)
            m_loss = CrossEntropyLoss()(input=output, target=labels)

            if soft_labels is not None:
                loss_distill = -sum(F.log_softmax(output, dim=-1)*soft_labels,dim=-1)
                loss_distill = loss_distill.mean()
                m_loss = (1-alpha)*m_loss + .alpha*loss_distill
        return m_loss
   
    @torch.no_grad()        
    def _momentum_update(self):          
                param_m= param_m * momentum + param * (1 - momentum) #EMA

Gradient Descent with Momentum

 

Gradient descent의 문제는 업데이트되는 weight가 해당 순간의 learning rate와 gradient에 의해서만 결정된다는 것이다. cost space를 순회하는 동안 과거 step는 고려하지 않는다.

weight update

이는 다음과 같은 문제로 이어진다.

saddle point (source: wiki)

  1. Saddle point에서 cost function의 gradient는 거의 0이며, 결과적으로 weight 업데이트가 적거나 없다. 따라서 네트워크가 정체되고 학습이 중지된다.
  2. Gradient Descent에 의한 경로는 mini batch 모드에서도 매우 불안정하다.

https://towardsdatascience.com/gradient-descent-with-momentum-59420f626c8f

초기 가중치가 위 그림에서 A 포인트라고 가정하자. (위 그림은 loss map이다) Gradient descent를 사용하면 loss function은 AB 경사면으로 빠르게 감소한다. AB 경사면의 gradient가 매우 크기 때문이다. 하지만 B point에 도달하면 gradient가 작아지고, weight update도 줄어든다. 많은 iteration이 지난 후에도 cost는 gradient가 0이 되어 정체될 때까지 매우 느리게 이동한다.

 

이 경우에 이상적으로 cost는 global minima 포인트 C로 이동했어야 했다. 하지만 gradient가 B에서 소멸되었기 때문에, sub-optimal 솔루션에 정체되었다. 

 

 

어떻게 Momentum이 이를 해결할 수 있을까?

이제 point A로부터 공이 굴러간다고 상상해보자. 공이 천천히 굴러가기 시작하고, AB 경사면에서 momentum을 수집한다. 공이 B 포인트에 도달했을 때, 공은 충분한 모멘텀(momentum)을 모은 상태이며, 모멘텀으로 B 포인트를 벗어나 BC 경사면으로, global minima인 C 포인트까지 도달할 수 있다.

 

Gradient Descent + Momentum

모멘텀을 설명하기 위해서, 과거 gradient들에 대해 moving average를 사용할 수 있다.