# Scraped from a Hugging Face Space page (page status: "No application file").
# File size: 2,105 bytes; commit 81685e4.
from anomalib.data.image.folder import Folder
from anomalib.models import EfficientAd
from anomalib.models.image.efficient_ad.lightning_model import EfficientAdModelSize
from anomalib.data.base.dataset import TaskType
from anomalib.data.utils import TestSplitMode, ValSplitMode
from anomalib.engine import Engine
from cog import Path, BaseModel, Input
import io
import shutil
class TrainingOutput(BaseModel):
    """Artifacts reported back to the caller when ``train`` finishes.

    Field declaration order is kept as-is because it defines the order of the
    generated output schema.
    """

    # Path to the Lightning checkpoint (``.../weights/lightning/model.ckpt``)
    # produced by training.
    weights: Path
    # Presumably the root directory of the staged training dataset; confirm
    # against the value ``train`` passes in.
    dataset_root: Path
    # NOTE(review): location of pretrained artifacts. Nothing visible here
    # writes to it, so confirm who populates it.
    pretrained: Path
def _stage_normal_images(sources: list[Path], destination: Path) -> None:
    """Copy the normal training images from *sources* into *destination*.

    Each entry in *sources* may be either of:
      * a regular file, which is copied as-is;
      * a directory, whose immediate regular files are copied (subdirectories
        are not descended into).

    Everything lands flat in *destination*, so identically named files from
    different sources overwrite one another.

    Raises:
        FileNotFoundError: if an entry is neither an existing file nor an
            existing directory.
    """
    for source in sources:
        if source.is_file():
            shutil.copy(source, destination)
        elif source.is_dir():
            for entry in source.iterdir():
                if entry.is_file():
                    shutil.copy(entry, destination)
        else:
            raise FileNotFoundError(f"Training input does not exist: {source}")


def train(normal_dir: list[Path] = Input(description="A file containing training normal data"),) -> TrainingOutput:
    """Train an EfficientAd model on the supplied normal images.

    The inputs are staged under ``dataset/normal`` and wrapped in an anomalib
    ``Folder`` datamodule. That datamodule uses ``TestSplitMode.SYNTHETIC``
    and ``ValSplitMode.FROM_TEST``. The model is then fitted with
    ``Engine.train``.

    Args:
        normal_dir: Files and/or directories containing normal training images.

    Returns:
        A TrainingOutput holding three paths:
          * the checkpoint path;
          * the dataset root directory;
          * the ``pre_trained`` path.

    Raises:
        ValueError: if ``normal_dir`` is empty.
        FileNotFoundError: if an entry of ``normal_dir`` does not exist.
    """
    if not normal_dir:
        raise ValueError("normal_dir must contain at least one file or directory")

    _normal_dir = Path("normal")
    _dataset_dir = Path("dataset")
    _dataset_normal_dir = _dataset_dir / _normal_dir
    _dataset_normal_dir.mkdir(parents=True, exist_ok=True)
    _stage_normal_images(normal_dir, _dataset_normal_dir)

    datamodule = Folder(
        name="hazelnut_toy",
        normal_dir=str(_normal_dir),
        root=str(_dataset_dir),
        abnormal_dir=None,
        normal_test_dir=None,
        mask_dir=None,
        normal_split_ratio=0.2,
        extensions=None,
        train_batch_size=1,
        eval_batch_size=32,
        num_workers=8,
        task=TaskType.SEGMENTATION,
        image_size=None,
        transform=None,
        train_transform=None,
        eval_transform=None,
        test_split_mode=TestSplitMode.SYNTHETIC,
        test_split_ratio=0.2,
        val_split_mode=ValSplitMode.FROM_TEST,
        val_split_ratio=0.5,
        seed=None,
    )
    datamodule.setup()

    # NOTE(review): the staged training dataset is reused as EfficientAd's
    # ``imagenet_dir``. Confirm this substitution for ImageNet-style data is
    # intentional.
    model = EfficientAd(
        imagenet_dir=_dataset_dir,
        teacher_out_channels=384,
        model_size=EfficientAdModelSize.S,
        lr=0.0001,
        weight_decay=1e-05,
        padding=False,
        pad_maps=True,
    )
    engine = Engine()
    engine.train(datamodule=datamodule, model=model)

    # NOTE(review): hard-coded output location. The results directory
    # presumably depends on the model and datamodule names (the datamodule is
    # named "hazelnut_toy"). Confirm the "dataset" segment matches what the
    # Engine actually writes.
    weights_file = "results/EfficientAd/dataset/latest/weights/lightning/model.ckpt"

    # ``dataset_root`` is the directory given to ``Folder(root=...)``.
    # The previous ``Path(normal_dir)`` passed a list to ``Path`` and raised
    # TypeError.
    return TrainingOutput(weights=Path(weights_file), dataset_root=_dataset_dir, pretrained=Path("pre_trained"))
# Example inference with a trained checkpoint:
# anomalib predict --return_predictions false --ckpt_path results/EfficientAd/avatar_rigging/latest/weights/lightning/model.ckpt --config results/EfficientAd/avatar_rigging/latest/config.yaml