File size: 340 Bytes
079c32c
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
import pytest
import torch
from lzero.model import ImageTransforms


@pytest.mark.unittest
def test_image_transform():
    """Check that ``ImageTransforms`` augments a batch without changing its shape.

    Builds a random batch of shape (4, 3, 96, 96), applies the 'shift' and
    'intensity' transforms, and asserts that:

    * the transformed batch keeps the input's shape, and
    * the transformed batch is not element-wise identical to the input.
    """
    img = torch.rand((4, 3, 96, 96))
    transform = ImageTransforms(['shift', 'intensity'])
    processed_img = transform.transform(img)
    # Check the *output* shape: asserting on `img` (the untouched input) can
    # never fail and so would not test the transform at all.
    assert processed_img.shape == (4, 3, 96, 96)
    # The transforms are expected to modify the data, so at least one element
    # of the output should differ from the input.
    assert not (img == processed_img).all()