hydraは階層構造を持つ設定を扱うためのPythonフレームワーク
設定ファイルを自由自在に読み込む
以下のように hydra.initialize()
で設定ファイルのディレクトリを指定して, hydra.compose()
で設定ファイルを読み込む
from hydra import initialize, compose
from omegaconf import OmegaConf
with initialize(version_base=None, config_path="conf/sub"):
cfg1 = compose(config_name="setting_1.yaml")
print(OmegaConf.to_yaml(cfg1))
cfg2 = compose(config_name="setting_2.yaml")
print(OmegaConf.to_yaml(cfg2))
以下のように,yaml形式で保存した設定を取得できる
Config (conf/config.yaml
)
db:
driver: mysql
user: omry
pass: secret
Application (main.py
)
import hydra
from omegaconf import DictConfig, OmegaConf
@hydra.main(version_base=None, config_path="conf", config_name="config")
def my_app(cfg : DictConfig) -> None:
print(OmegaConf.to_yaml(cfg))
if __name__ == "__main__":
my_app()
- サンプルプログラムの設定をyaml形式で保存したファイルは以下の通り.
???
という文字列は {AnyNode} ??? <MISSING>
という値が登録されていたので,NaN的な位置付けなんだと思う.
batch_size: 64
test_batch_size: 1000
epochs: 14
no_cuda: false
dry_run: false
seed: 1
log_interval: 10
save_model: false
checkpoint_name: unnamed.pt
adadelta:
_target_: torch.optim.adadelta.Adadelta
params: ???
lr: 1.0
rho: 0.9
eps: 1.0e-06
weight_decay: 0
steplr:
_target_: torch.optim.lr_scheduler.StepLR
optimizer: ???
step_size: 1
gamma: 0.1
last_epoch: -1
階層構造を持つときの対応
@dataclass
で定義したConfigクラスにmutable なデフォルト値を設定しようとするとこんなエラーが発生
- Configクラスを入れ子式にしてデフォルト値を設定しようとすると
strオブジェクトは()使えない
的なエラーが発生した
- Configクラスのオブジェクトをデフォルト値として設定するときに注意する必要あり
- 公式チュートリアルのコードを参考にすればいけた.
from dataclasses import dataclass, field
from hydra_configs.torch.optim import AdadeltaConf
from hydra_configs.torch.optim.lr_scheduler import StepLRConf
@dataclass
class MNISTConf:
batch_size: int = 64
test_batch_size: int = 1000
epochs: int = 14
no_cuda: bool = False
dry_run: bool = False
seed: int = 1
log_interval: int = 10
save_model: bool = False
checkpoint_name: str = "unnamed.pt"
adadelta: AdadeltaConf = field(default_factory=AdadeltaConf)
steplr: StepLRConf = field(default_factory=lambda: StepLRConf(step_size=1))
- サンプル
hydra.utils.instantiate()
: インスタンス生成 (call()のエイリアス)
hydra.utils.call()
: 関数の実行
hydra.utils.instantiate()
を使うことで,import
しないでオブジェクトを生成できる
- 関数への引数を上書きしたい場合は
instantiate()
に上書きしたい引数名・引数を入れる (override parmeters って表現してる)
optimizer:
_target_: my_app.Optimizer
algo: SGD
lr: 0.01
@hydra.main(version_base=None, config_path='conf/', config_name='config')
def main(cfg):
opt = instantiate(cfg.optimizer)
opt = instantiate(cfg.optimizer, lr=0.2)
- PyTorch用のconfigを扱えるようにするパッケージ
- チュートリアルではyaml形式ではなくpyファイル内に設定をベタ書きする方法を紹介
- yaml形式で設定を保存する方法が分からなかったので,読み込んだ設定を
OmegaConf.save()
で保存したファイルを使うようにした
読み込んだ設定からoptimizer, schedulerのオブジェクトを生成する機能はないっぽい (hydra.utils.instantiate()
で可能)