# Filter GIFs by visual similarity to a reference image, scored with CLIP image embeddings.
from PIL import Image, ImageSequence
import requests
from tqdm import tqdm
import numpy as np
import torch
from transformers import CLIPProcessor, CLIPModel

def sample_frame_indices(clip_len, frame_sample_rate, seg_len):
    """Sample clip_len frame indices at the given rate from a random window of a seg_len-frame video."""
    converted_len = int(clip_len * frame_sample_rate)
    # Pick a random converted_len-frame window, then spread clip_len evenly
    # spaced indices across it.
    end_idx = np.random.randint(converted_len, seg_len)
    start_idx = end_idx - converted_len
    indices = np.linspace(start_idx, end_idx, num=clip_len)
    indices = np.clip(indices, start_idx, end_idx - 1).astype(np.int64)
    return indices
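
# sample_frame_indices is not called below; a hypothetical use, assuming a
# decoded 120-frame video, would pick 8 indices from a random 32-frame window:
#   indices = sample_frame_indices(clip_len=8, frame_sample_rate=4, seg_len=120)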

def load_frames(image: Image.Image, mode='RGBA'):
    """Decode every frame of an animated image into one (n_frames, H, W, 4) array."""
    return np.array([
        np.array(frame.convert(mode))
        for frame in ImageSequence.Iterator(image)
    ])

# Load CLIP once at import time; the reference image and the GIF frames are
# embedded with the same model.
img_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
img_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
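# Optional (an assumption, not part of the original script): move the model to
# GPU when available, e.g.:
#   device = "cuda" if torch.cuda.is_available() else "cpu"
#   img_model = img_model.to(device)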

def filter_gifs(gifs, input_image, threshold=0.9):
    """Return the GIFs whose frames are, on average, more similar to
    input_image than the cosine-similarity threshold."""
    matches = []

    # Embed the reference image once; it is identical for every GIF.
    image = Image.open(input_image)
    inputs_base = img_processor(images=image, return_tensors="pt")
    with torch.no_grad():
        feat_img_base = img_model.get_image_features(pixel_values=inputs_base["pixel_values"])

    for gif in tqdm(gifs, total=len(gifs)):
        with Image.open(gif) as im:
            frames = load_frames(im)

        # Drop the alpha channel, move channels first, and skip the initial
        # frame, which in GIFs is often a palette or placeholder frame.
        frames = np.transpose(frames[:, :, :, :3], (0, 3, 1, 2))[1:]

        inputs = img_processor(images=list(frames), return_tensors="pt")
        with torch.no_grad():
            feat_img_vid = img_model.get_image_features(pixel_values=inputs["pixel_values"])

        # Average the cosine similarity between the reference embedding and
        # every frame embedding (broadcast over the frame dimension).
        cos_avg = torch.nn.functional.cosine_similarity(
            feat_img_base, feat_img_vid, dim=1).mean().item()

        print("Current cosine similarity:", cos_avg)
        print("Similarity threshold:", threshold)
        if cos_avg > threshold:
            matches.append(gif)
    return matches
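
# Minimal usage sketch, assuming a directory of candidate GIFs and a reference
# image; both paths are placeholders, not files shipped with this script:
if __name__ == "__main__":
    import glob
    matching = filter_gifs(sorted(glob.glob("gifs/*.gif")), "reference.png")
    print("GIFs above the similarity threshold:", matching)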