import datetime
import random
from argparse import ArgumentParser
from pathlib import Path

import numpy as np
import torch
import yaml
from tqdm import tqdm

from datasets.video_dataset import AtlasDataset
from models.video_model import VideoModel
from util.atlas_loss import AtlasLoss
from util.util import get_optimizer
from util.video_logger import DataLogger
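
# Training entry point: seeds all RNGs, builds the training components from the
# config, and runs the optimization loop. Note that wandb is imported in the
# __main__ block below (only when config["use_wandb"] is set), which makes it
# available here as a module-level name.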
def train_model(config):
    # set seed
    seed = config["seed"]
    if seed == -1:
        seed = np.random.randint(2 ** 32)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    print(f"running with seed: {seed}.")
    dataset = AtlasDataset(config)
    model = VideoModel(config)
    criterion = AtlasLoss(config)
    optimizer = get_optimizer(config, model.parameters())
    logger = DataLogger(config, dataset)
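
    # each "epoch" is one optimization step over the full set of video inputs
    # returned by dataset[0]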
    with tqdm(range(1, config["n_epochs"] + 1)) as tepoch:
        for epoch in tepoch:
            inputs = dataset[0]
            optimizer.zero_grad()
            outputs = model(inputs)
            losses = criterion(outputs, inputs)
            loss = 0.0
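            # optimize only the layer selected in the config; exactly one of
            # finetune_foreground / finetune_background is expected to be set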
if config["finetune_foreground"]:
loss += losses["foreground"]["loss"]
elif config["finetune_background"]:
loss += losses["background"]["loss"]
            lr = optimizer.param_groups[0]["lr"]
            log_data = logger.log_data(epoch, lr, losses, model, dataset)
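
            # backprop, optimizer step, and exponential LR decay clamped at min_lr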
            loss.backward()
            optimizer.step()
            optimizer.param_groups[0]["lr"] = max(config["min_lr"], config["gamma"] * optimizer.param_groups[0]["lr"])
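
            # log to wandb when enabled, otherwise save results locally every
            # log_images_freq epochs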
if config["use_wandb"]:
wandb.log(log_data)
else:
if epoch % config["log_images_freq"] == 0:
logger.save_locally(log_data)
tepoch.set_description(f"Epoch {epoch}")
tepoch.set_postfix(loss=loss.item())
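
# CLI entry point: loads the base video config, overrides it with the chosen
# example config, sets up wandb or a local results folder, and starts training.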
if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument(
        "--config",
        default="./configs/video_config.yaml",
        help="Config path",
    )
    parser.add_argument(
        "--example_config",
        default="car-turn_winter.yaml",
        help="Example config name",
    )
    args = parser.parse_args()
    config_path = args.config

    with open(config_path, "r") as f:
        config = yaml.safe_load(f)
    with open(f"./configs/video_example_configs/{args.example_config}", "r") as f:
        example_config = yaml.safe_load(f)
    config["example_config"] = args.example_config
    config.update(example_config)
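
    # name the run after the parent folder of the checkpoint path; track it in
    # wandb when enabled, otherwise write results to a timestamped local folder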
run_name = f"-{config['checkpoint_path'].split('/')[-2]}"
if config["use_wandb"]:
import wandb
wandb.init(project=config["wandb_project"], entity=config["wandb_entity"], config=config, name=run_name)
wandb.run.name = str(wandb.run.id) + wandb.run.name
config = dict(wandb.config)
else:
now = datetime.datetime.now()
run_name = f"{now.strftime('%Y-%m-%d_%H-%M-%S')}{run_name}"
path = Path(f"{config['results_folder']}/{run_name}")
path.mkdir(parents=True, exist_ok=True)
with open(path / "config.yaml", "w") as f:
yaml.dump(config, f)
config["results_folder"] = str(path)
train_model(config)
if config["use_wandb"]:
wandb.finish() |