I am Charmie

メモとログ

PyTorch Lightning

PyTorch Lightningを使ってみた.

https://pytorch-lightning.readthedocs.io/en/0.7.1/_images/pt_to_pl.jpg

  • この解説がわかりやすかった
  • PyTorchにおけるBoilerplate codeを減らすための工夫が施されている
  • 学習の処理(一バッチ分)をモデルクラスの関数に書くことで,学習の二重ループを書かなくて済む
    • lossfun: ロス関数の定義
    • configure_optimizers: optimizerの定義
    • train_dataloader: 学習データのデータローダの設定(@pl.data_loadeで修飾する必要あり)
    • val_dataloader: 検証データのデータローダの設定(@pl.data_loadeで修飾する必要あり)
    • training_step: 一バッチ分の学習ステップを定義
    • validation_step: 一バッチ分の検証ステップを定義
    • validation_end: 一エポック分の検証ステップ終了時の計算を定義
  • optimizer,データローダについてもモデルクラスの関数に書けるんだけど気持ち悪い.何か思想的な部分で理解が足りないんだと思う
  • 色々なレシピを公開してくれているのもありがたい
  • optimizerの説明を見たけど,optimizerを途中で切り替える方法が分からなかった.
    • やりたいこと: 学習の途中で更新するパラメータ,学習率を切り替えたい
    • この方法on_train_epoch_startでエポック数に応じて切り替える)が一番わかり易い

サンプルプログラムは以下の通りで,学習のプログラムが簡潔に書ける.

# main.py
# ! pip install torchvision
import torch, torch.nn as nn, torch.utils.data as data, torchvision as tv, torch.nn.functional as F
import lightning as L

# --------------------------------
# Step 1: Define a LightningModule
# --------------------------------
# A LightningModule (nn.Module subclass) defines a full *system*
# (ie: an LLM, diffusion model, autoencoder, or simple image classifier).


class LitAutoEncoder(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 3))
        self.decoder = nn.Sequential(nn.Linear(3, 128), nn.ReLU(), nn.Linear(128, 28 * 28))

    def forward(self, x):
        # in lightning, forward defines the prediction/inference actions
        embedding = self.encoder(x)
        return embedding

    def training_step(self, batch, batch_idx):
        # training_step defines the train loop. It is independent of forward
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer


# -------------------
# Step 2: Define data
# -------------------
dataset = tv.datasets.MNIST(".", download=True, transform=tv.transforms.ToTensor())
train, val = data.random_split(dataset, [55000, 5000])

# -------------------
# Step 3: Train
# -------------------
autoencoder = LitAutoEncoder()
trainer = L.Trainer()
trainer.fit(autoencoder, data.DataLoader(train), data.DataLoader(val))