Spaces:
No application file
No application file
from anomalib.data.image.folder import Folder | |
from anomalib.models import EfficientAd | |
from anomalib.models.image.efficient_ad.lightning_model import EfficientAdModelSize | |
from anomalib.data.base.dataset import TaskType | |
from anomalib.data.utils import TestSplitMode, ValSplitMode | |
from anomalib.engine import Engine | |
from cog import Path, BaseModel, Input | |
import io | |
import shutil | |
class TrainingOutput(BaseModel):
    """Paths handed back to the caller once ``train`` finishes."""

    # Location of the trained model checkpoint file.
    weights: Path
    # Root directory of the dataset assembled for training.
    dataset_root: Path
    # Presumably the directory holding pretrained assets used by the model — confirm.
    pretrained: Path
# Name of the anomalib Folder datamodule. The checkpoint path returned by
# ``train`` embeds this name (results/EfficientAd/<name>/latest/...), so both
# are derived from this single constant to keep them in sync.
_DATASET_NAME = "hazelnut_toy"


def _stage_normal_images(sources: list[Path], destination: Path) -> int:
    """Copy training images from *sources* into the flat *destination* directory.

    Each entry of *sources* may be a directory (its immediate regular files are
    copied; subdirectories are not descended into) or a single regular file
    (copied directly). Files that share a basename overwrite one another.

    Returns:
        The number of copy operations performed.

    Raises:
        FileNotFoundError: if an entry is neither a directory nor a regular file.
    """
    destination.mkdir(parents=True, exist_ok=True)
    copied = 0
    for source in sources:
        if source.is_dir():
            files = [p for p in source.iterdir() if p.is_file()]
        elif source.is_file():
            # Cog delivers Path inputs as individual files, not directories.
            files = [source]
        else:
            raise FileNotFoundError(f"Training input does not exist: {source}")
        for file_path in files:
            shutil.copy(file_path, destination)
            copied += 1
    return copied


def train(
    normal_dir: list[Path] = Input(
        description="Normal (non-anomalous) training images: image files and/or directories containing them"
    ),
) -> TrainingOutput:
    """Train an EfficientAd model on the supplied normal images.

    The inputs are copied into ``dataset/normal`` and wrapped in an anomalib
    ``Folder`` datamodule. That datamodule uses a synthetic test split and takes
    its validation split from the test split. The data is then used to fit a
    small (``S``) EfficientAd model.

    Args:
        normal_dir: Image files and/or directories of images to train on.

    Returns:
        TrainingOutput holding the checkpoint path, the dataset root directory
        and the ``pre_trained`` directory path.

    Raises:
        ValueError: if ``normal_dir`` yielded no files to train on.
        FileNotFoundError: if an input path does not exist, or if no checkpoint
            was found at the expected location after training.
    """
    normal_subdir = Path("normal")
    dataset_dir = Path("dataset")

    copied = _stage_normal_images(normal_dir, dataset_dir / normal_subdir)
    if copied == 0:
        raise ValueError("normal_dir contained no files to train on")

    datamodule = Folder(
        name=_DATASET_NAME,
        normal_dir=str(normal_subdir),  # resolved relative to ``root``
        root=str(dataset_dir),
        abnormal_dir=None,
        normal_test_dir=None,
        mask_dir=None,
        normal_split_ratio=0.2,
        extensions=None,
        train_batch_size=1,
        eval_batch_size=32,
        num_workers=8,
        task=TaskType.SEGMENTATION,
        image_size=None,
        transform=None,
        train_transform=None,
        eval_transform=None,
        test_split_mode=TestSplitMode.SYNTHETIC,
        test_split_ratio=0.2,
        val_split_mode=ValSplitMode.FROM_TEST,
        val_split_ratio=0.5,
        seed=None,
    )
    datamodule.setup()

    # NOTE(review): ``imagenet_dir`` points at the training dataset itself. It is
    # presumably meant for a separate ImageNet-style image folder; pointing it
    # here likely avoids a download. Confirm this is intentional.
    model = EfficientAd(
        imagenet_dir=dataset_dir,
        teacher_out_channels=384,
        model_size=EfficientAdModelSize.S,
        lr=0.0001,
        weight_decay=1e-05,
        padding=False,
        pad_maps=True,
    )

    engine = Engine()
    engine.train(datamodule=datamodule, model=model)

    # Assumes anomalib's results layout results/<model>/<datamodule name>/latest/
    # (as in the example predict command for this script); verify against the
    # installed anomalib version if the check below fails.
    weights_file = (
        Path("results") / "EfficientAd" / _DATASET_NAME / "latest"
        / "weights" / "lightning" / "model.ckpt"
    )
    if not weights_file.is_file():
        raise FileNotFoundError(f"Expected checkpoint not found after training: {weights_file}")

    return TrainingOutput(
        weights=weights_file,
        dataset_root=dataset_dir,
        pretrained=Path("pre_trained"),
    )
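# Example inference command for a checkpoint produced by training. Replace
# "avatar_rigging" with the datamodule name that was used at training time:
# anomalib predict --return_predictions false --ckpt_path results/EfficientAd/avatar_rigging/latest/weights/lightning/model.ckpt --config results/EfficientAd/avatar_rigging/latest/config.yaml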