-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
28 lines (20 loc) · 855 Bytes
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import hydra
import argparse
from omegaconf import OmegaConf, DictConfig
from datamodule import SpeechCommandDataModule
from model import ModelModule
import pytorch_lightning as pl
def train(cfg: DictConfig) -> None:
    """Build data, model, logger and trainer from *cfg*, then fit and test.

    Each config section is unpacked as keyword arguments:
    ``cfg.datamodule`` -> ``SpeechCommandDataModule``, ``cfg.model`` ->
    ``ModelModule``, ``cfg.logger`` -> ``TensorBoardLogger`` and
    ``cfg.trainer`` -> ``pl.Trainer``.

    Args:
        cfg: The composed Hydra configuration.
    """
    dm = SpeechCommandDataModule(**cfg.datamodule)
    model = ModelModule(**cfg.model)
    logger = pl.loggers.tensorboard.TensorBoardLogger(**cfg.logger)
    trainer = pl.Trainer(logger=logger, **cfg.trainer)
    trainer.fit(model, datamodule=dm)
    # NOTE(review): this evaluates the `model` object passed in; whether that
    # means final or best-checkpoint weights depends on the Lightning version.
    # Consider `ckpt_path="best"` if best-checkpoint evaluation is intended.
    trainer.test(model, datamodule=dm)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Config path")
    parser.add_argument("-cp", help="config path")  # config path
    parser.add_argument("-cn", help="config name")  # config name
    # parse_known_args (not parse_args): Hydra overrides such as
    # `trainer.max_epochs=5` must not be rejected here. argparse leaves
    # sys.argv untouched, and hydra.main re-parses it (it also accepts -cp/-cn).
    args, _unknown = parser.parse_known_args()

    @hydra.main(config_path=args.cp, config_name=args.cn)
    def main(cfg: DictConfig) -> None:
        """Hydra entry point: compose the config and run training."""
        train(cfg)

    main()