1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51
| import argparse import torch import importlib
from utils.io_utils import seed_everything from utils.conf import load_config
def get_args(argv=None):
    """Parse the options this entry point understands.

    Only the options declared below are consumed; every other argument string
    is returned untouched in ``extras`` (in this script they are forwarded to
    ``load_config`` as ``cli_args``).

    Args:
        argv: Argument strings to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, exactly as before this parameter existed.

    Returns:
        ``(args, extras)``: the parsed ``argparse.Namespace`` and the list of
        argument strings that were not recognised.

    Raises:
        SystemExit: if ``--config`` is missing or ``--stage`` is not one of
            ``train``/``test``/``exp`` (argparse reports the error and exits).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', required=True, help='path to config file')
    parser.add_argument('--resume', default=None, help='path to the weights to be resumed')
    # Restrict to the stages the script actually dispatches on, so a typo is
    # rejected at parse time instead of silently running no stage at all.
    parser.add_argument('--stage', default=None, choices=('train', 'test', 'exp'),
                        help='train or test or exp')
    parser.add_argument('--seed', type=int, default=1771, help='seed for initializing training.')
    args, extras = parser.parse_known_args(argv)
    return args, extras
def get_obj_from_str(path):
    """Return the object named by a dotted ``"package.module.Name"`` path.

    The script uses this to turn the ``"func"`` string of each config section
    (``model``, ``dataset``, ``system``) into the callable that builds it.

    Args:
        path: Dotted path; everything before the last ``.`` is imported as a
            module and the final component is looked up on that module.

    Returns:
        The attribute named by the last path component.

    Raises:
        ValueError: if ``path`` contains no ``.`` (there is no module to import).
        ImportError: (e.g. ``ModuleNotFoundError``) if the module cannot be imported.
        AttributeError: if the module has no attribute with that name.
    """
    module_name, sep, attr_name = path.rpartition(".")
    if not sep:
        raise ValueError(f"expected a dotted 'module.Name' path, got {path!r}")
    return getattr(importlib.import_module(module_name), attr_name)


if __name__ == "__main__":
    args, extras = get_args()

    # Fail before building anything: previously an unknown or missing stage
    # constructed the model, dataset and trainer and then ran nothing.
    stages = ("train", "test", "exp")
    if args.stage not in stages:
        raise SystemExit(
            f"unknown --stage {args.stage!r}; expected one of: {', '.join(stages)}"
        )

    if args.seed is not None:
        seed_everything(args.seed)

    # Unrecognised CLI arguments are passed through as config overrides.
    config = load_config(args.config, cli_args=extras)

    model_config = config["model"]
    model_cls = get_obj_from_str(model_config["func"])
    model = model_cls(**model_config.get("params", {}))

    dataset_config = config["dataset"]
    dataset_cls = get_obj_from_str(dataset_config["func"])
    # The stage name is both the dataset's first positional argument and the
    # key under which this stage's DataLoader kwargs live in the config.
    dataset = dataset_cls(args.stage, **dataset_config.get("datasets", {}))
    dataloader = torch.utils.data.DataLoader(dataset, **dataset_config.get(args.stage, {}))

    system_config = config["system"]
    system_cls = get_obj_from_str(system_config["func"])
    trainer = system_cls(model, dataloader, config, args)

    # Stage names double as the trainer method names: train() / test() / exp().
    getattr(trainer, args.stage)()
|