Spaces:
Runtime error
Runtime error
#!/usr/bin/env python3 | |
# Copyright (c) Megvii Inc. All rights reserved. | |
import ast | |
import pprint | |
from abc import ABCMeta, abstractmethod | |
from typing import Dict, List, Tuple | |
from tabulate import tabulate | |
import torch | |
from torch.nn import Module | |
from yolox.utils import LRScheduler | |
class BaseExp(metaclass=ABCMeta):
    """Basic class for any experiment.

    Stores a few generic settings as plain instance attributes and declares
    the hooks (``get_model``, ``get_dataset``, ...) that concrete experiments
    provide.  The hooks are not decorated with ``@abstractmethod`` here, so
    the base implementations simply return ``None`` and overriding them is
    not enforced.

    Settings can be overridden from a flat ``[key, value, key, value, ...]``
    list via :meth:`merge`.
    """

    # Spellings recognised when a textual override targets a ``bool`` setting.
    _TRUE_STRINGS = frozenset({"1", "true", "t", "yes", "y", "on"})
    _FALSE_STRINGS = frozenset({"0", "false", "f", "no", "n", "off"})

    def __init__(self):
        # Random seed; ``None`` by default.
        self.seed = None
        # Default output directory path.
        self.output_dir = "./YOLOX_outputs"
        # Interval settings; the unit (iterations vs. epochs) is decided by
        # whoever consumes them -- TODO confirm against the trainer.
        self.print_interval = 100
        self.eval_interval = 10
        # Dataset slot, unset by default.
        self.dataset = None

    def get_model(self) -> Module:
        """Return the model of this experiment (base implementation: ``None``)."""

    def get_dataset(self, cache: bool = False, cache_type: str = "ram"):
        """Return the dataset of this experiment (base implementation: ``None``).

        Args:
            cache: caching switch; its exact meaning is defined by the
                overriding experiment.
            cache_type: caching backend name, ``"ram"`` by default; accepted
                values are defined by the overriding experiment.
        """

    def get_data_loader(
        self, batch_size: int, is_distributed: bool
    ) -> Dict[str, torch.utils.data.DataLoader]:
        """Return data loaders keyed by name (base implementation: ``None``)."""

    def get_optimizer(self, batch_size: int) -> torch.optim.Optimizer:
        """Return the optimizer (base implementation: ``None``)."""

    def get_lr_scheduler(
        self, lr: float, iters_per_epoch: int, **kwargs
    ) -> LRScheduler:
        """Return the learning-rate scheduler (base implementation: ``None``)."""

    def get_evaluator(self):
        """Return the evaluator (base implementation: ``None``)."""

    def eval(self, model, evaluator, weights):
        """Evaluation hook (base implementation does nothing)."""

    def __repr__(self):
        """Render the public instance attributes as a two-column table.

        Attributes whose name starts with ``_`` are skipped; values are
        formatted with ``pprint.pformat``.
        """
        table_header = ["keys", "values"]
        exp_table = [
            (str(k), pprint.pformat(v))
            for k, v in vars(self).items()
            if not k.startswith("_")
        ]
        return tabulate(exp_table, headers=table_header, tablefmt="fancy_grid")

    @classmethod
    def _coerce(cls, value, target_type):
        """Convert ``value`` to ``target_type`` for :meth:`merge`.

        Same as ``target_type(value)`` except for strings aimed at ``bool``:
        ``bool("False")`` is ``True`` because every non-empty string is
        truthy, so the usual textual spellings are mapped explicitly.
        Unrecognised strings keep the plain ``bool(value)`` behaviour.
        """
        if target_type is bool and isinstance(value, str):
            token = value.strip().lower()
            if token in cls._TRUE_STRINGS:
                return True
            if token in cls._FALSE_STRINGS:
                return False
        return target_type(value)

    @classmethod
    def _parse_sequence(cls, raw, reference):
        """Turn an override for a list/tuple setting into a list of items.

        Args:
            raw: the override; either text such as ``"(640, 640)"`` or an
                actual list/tuple.  Any other object is returned unchanged.
            reference: the current value of the setting; the type of its
                first element (if any) is used to convert textual items.

        Returns:
            A list of items (or ``raw`` itself when it is neither text nor a
            list/tuple).
        """
        if isinstance(raw, str):
            text = raw.strip("[]()")
            # "[]" / "()" mean "no items", not one empty item.
            items = [t.strip() for t in text.split(",")] if text.strip() else []
            # Tolerate a trailing comma, e.g. the repr of a 1-tuple "(640,)".
            if items and items[-1] == "":
                items.pop()
        elif isinstance(raw, (list, tuple)):
            items = list(raw)
        else:
            return raw

        # find type of tuple
        if len(reference) > 0:
            item_type = type(reference[0])
            items = [
                cls._coerce(t, item_type) if isinstance(t, str) else t
                for t in items
            ]
        return items

    def merge(self, cfg_list):
        """Override existing settings from a flat ``[key, value, ...]`` list.

        Keys that are not already attributes of the experiment are ignored.
        Each value is converted to the type of the setting it replaces:

        * list/tuple settings accept text like ``"[1, 2]"`` / ``"(1, 2)"``
          (items converted to the type of the current first element) or a
          ready-made list/tuple;
        * ``bool`` settings understand ``"true"``/``"false"``, ``"1"``/``"0"``,
          ``"yes"``/``"no"``, ``"on"``/``"off"`` (case-insensitive);
        * settings that are currently ``None`` take the value as given;
        * if the direct conversion fails, ``ast.literal_eval`` is tried.

        Args:
            cfg_list: sequence of alternating keys and values.

        Raises:
            AssertionError: if ``cfg_list`` has an odd length.
        """
        assert len(cfg_list) % 2 == 0, f"length must be even, check value here: {cfg_list}"
        for k, v in zip(cfg_list[0::2], cfg_list[1::2]):
            # only update value with same key
            if not hasattr(self, k):
                continue
            src_value = getattr(self, k)
            src_type = type(src_value)

            # pre-process input if source type is list or tuple
            if isinstance(src_value, (list, tuple)):
                v = self._parse_sequence(v, src_value)

            if src_value is not None and src_type is not type(v):
                try:
                    v = self._coerce(v, src_type)
                except Exception:
                    # The setting's constructor rejected the value (it may be
                    # an arbitrary type); fall back to parsing it as a literal.
                    v = ast.literal_eval(v)
            setattr(self, k, v)