Home [Pytorch] Learning rate 관리하기
Post
Cancel

[Pytorch] Learning rate 관리하기

PyTorch, Pytorch-lightning을 이용해서 프로젝트를 진행하고 있는데
자주 쓸 것 같지만 정리해두지 않으면 까먹을 것 같은 트릭들의 모음

Learning rate관리는 training task의 핵심 중 하나이다. 특히 Pre-training task를 진행할 땐, learning rate를 잘 관리해주지 않으면 모델이 제대로 수렴되지 않는다. 이번 포스트에서는 다음 두개를 집중해서 정리한다.

  • Learning rate scheduler
  • Learning rate monitor

Learning rate scheduler

torch.optim.lr_scheduler 내에는 간편하게 사용할 수 있는 scheduler class들이 정의되어 있다.

Pytorch 공식 documentation here 에서 확인할 수 있고
나는 이 블로그 here를 많이 참고했었었다.

자세한 Learning rate scheduler들은 위 링크에 잘 정리되어있으니 넘어가도록 한다.

InverseSqrtScheduler

나는 이번 프로젝트에서 facebook research팀이 ESM model을 training할 때 사용한 Inverse Square Root with Warmup scheduler를 구현해서 썼다. (fairseq github)의 구현을 참고해서 작성했다.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
from torch.optim.lr_scheduler import LambdaLR


class InverseSqrtScheduler(LambdaLR):
    """ Linear warmup and then follows an inverse square root decay schedule
        Linearly increases learning rate schedule from 0 to 1 over `warmup_steps` training steps.
        Afterward, learning rate follows an inverse square root decay schedule.
    """

    def __init__(self, optimizer, warmup_steps, last_epoch=-1):
        def lr_lambda(step):
            if step < warmup_steps:
                return float(step) / float(max(1.0, warmup_steps))
            
            decay_factor = warmup_steps ** 0.5
            return decay_factor * step ** -0.5

        super(InverseSqrtScheduler, self).__init__(optimizer, lr_lambda, last_epoch=last_epoch)

Lambda scheduler를 사용할 때, lambda 함수는 정확한 learning rate를 반환하는게 아닌
maximum learning rate에 곱해질 factor (0 ~ 1) 값을 반환해주면 된다

이 scheduler를 사용하면 다음과 같은 형태로 learning rate가 관리된다.

inverse_sqrt

configure_optimizers()

이렇게 만든 scheduler를 pytorch-lightning에서 사용하는 것은 아주 간단하다. configure_optimizers() 함수 내에서 다음처럼 써주기만 하면 된다.

1
2
3
4
5
def configure_optimizers():
	optimizer = AdamW( ... )
	scheduler = InverseSqrtScheduler(optimizer, self.hparams.warmup_steps)

	return [optimizer], [scheduler]

그런데 한가지 주의해야할 사항이 있다. 위와같이 사용하면 learning rate가 epoch 단위로 관리되는 것이다. fine-tuning등 일반적인 training task에서는 굳이 learning rate를 step단위로 관리할 필요가 없지만, masked language model 이나 GAN과 같은 모델에서는 이야기가 다르다. 한 epoch가 굉장히 길기 때문에, step단위로 learning rate을 조절해 줄 필요가 있다. (굉장히 쉬운 테크닉인데, 검색해도 잘 나오지 않아서 하루종일 헤맸다 ㅠㅠ)

1
2
3
4
5
6
7
8
9
def configure_optimizers():
	optimizer = AdamW( ... )
	scheduler = InverseSqrtScheduler(optimizer, self.hparams.warmup_steps)
	sch_config = {
		"scheduler": scheduler,
		"interval": "step",
	}

	return [optimizer], [sch_config]

아주 작은 수정인데, intervalstep으로 설정해주면 scheduler의 step()함수가 매 epoch대신 매 step마다 호출된다. return [optimizer], [sch_config] 대신 다음과 같이 써도 괜찮다.

1
2
3
4
return {
	"optimizer": optimizer,
	"lr_scheduler": sch_config,
}

Learning rate monitor

log()

Learning rate가 잘 변하고 있는지 지켜보고 싶은 심정은 모두 똑같을 것이다. learning rate를 확인하는 직관적인 방법은 다음처럼 할 수 있다. training_step() 또는 validation_step()에서 log를 찍는 것이다.

1
2
3
4
5
6
7
8
def training_step():
	scheduler = self.lr_schedulers()
	loss = self.model(**batch).loss
	
	self.log('train_loss', loss, on_epoch=True)
	self.log('learning rate', scheduler.get_lr()[0], on_epoch=True)
	
	return loss

Callback

그런데 pytorch-lightning을 쓰는 이유 중 하나는 불필요한 코드들을 간단하게 정리하기 위함이 아닌가!! 저렇게 주저리주저리 코드를 다는 대신, 이미 구현되어있는 Callback을 가져다쓰는 편한 방법이 있다.

1
2
3
4
5
6
7
8
9
import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor

lr_monitor = LearningRateMonitor(logging_interval='step')
trainer = pl.Trainer(
	...
	callbacks = [lr_monitor],
	...
)

Scheduler Name

LearningRateMonitor를 이용해 모니터링한 값은 tensorboard를 이용해 확인할 수 있다. 기본적으로 optimizer의 이름으로 표시되며, 특별한 이름을 부여하고 싶다면 다음처럼 scheduler를 수정해주면 된다.

1
2
3
4
5
6
7
8
9
10
def config_optimizers():
	optimizer = AdamW( ... )
	scheduler = InverseSqrtScheduler(optimizer, self.hparams.warmup_steps)
	sch_config = {
		"scheduler": scheduler,
		"interval": "step",
		"name": "my_little_scheduler",
	}

	return [optimizer], [sch_config]
This post is licensed under CC BY 4.0 by the author.