File size: 2,992 Bytes
84c806e
 
 
 
 
6525b03
 
 
84c806e
 
 
6525b03
 
84c806e
6525b03
 
84c806e
7b207f0
 
 
84c806e
 
 
 
 
 
 
a811816
6525b03
 
84c806e
 
2e45025
 
 
 
 
 
 
a811816
 
6525b03
 
84c806e
 
 
6525b03
 
 
84c806e
a811816
84c806e
 
48a1fa8
 
 
 
 
 
 
 
 
 
84c806e
 
 
 
 
48a1fa8
 
 
 
 
 
 
 
 
84c806e
48a1fa8
84c806e
48a1fa8
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import os

import jax
import jax.numpy as jnp
import numpy as np
import requests
import streamlit as st
from PIL import Image

from utils import load_model


def split_image(im, num_rows=3, num_cols=3):
    """Split an image into a row-major list of equally sized grid tiles.

    The image height/width are floor-divided by ``num_rows``/``num_cols``;
    any remainder pixels on the bottom/right edges are dropped.

    Args:
        im: Image as an array-like (e.g. PIL image or ndarray), indexed
            as (row, col, ...).
        num_rows: Number of grid rows.
        num_cols: Number of grid columns.

    Returns:
        List of ``num_rows * num_cols`` ndarray tiles, ordered row by row.
    """
    arr = np.array(im)
    tile_h = arr.shape[0] // num_rows
    tile_w = arr.shape[1] // num_cols
    tiles = []
    for r in range(num_rows):
        top = r * tile_h
        for c in range(num_cols):
            left = c * tile_w
            tiles.append(arr[top : top + tile_h, left : left + tile_w])
    return tiles


def app(model_name):
    """Render the patch-based relevance ranking page.

    Loads the KoCLIP model named by ``model_name``, lets the user supply an
    image (URL or upload) and a text prompt, then scores every tile of an
    image grid against the prompt and displays tiles from most to least
    relevant.
    """
    model, processor = load_model(f"koclip/{model_name}")

    st.title("Patch-based Relevance Ranking")
    st.markdown(
        """
        Given a piece of text, the CLIP model finds the part of an image that best explains the text.
        To try it out, you can

        1. Upload an image
        2. Explain a part of the image in text

        which will yield the most relevant image tile from a grid of the image. You can specify how
        granular you want to be with your search by specifying the number of rows and columns that
        make up the image grid.

        ---
        """
    )

    url_input = st.text_input(
        "Enter a URL to an image...",
        value="https://img.sbs.co.kr/newimg/news/20200823/201463830_1280.jpg",
    )
    uploaded_file = st.file_uploader("or upload an image...", type=["jpg", "jpeg", "png"])
    prompt = st.text_input(
        "Enter a prompt to query the image.",
        value="이건 서울의 경복궁 사진이다.",
    )

    # NOTE(review): st.beta_columns is deprecated in newer Streamlit releases
    # (renamed to st.columns) — kept as-is to match the pinned version.
    left_col, right_col = st.beta_columns(2)
    with left_col:
        num_rows = st.slider(
            "Number of rows", min_value=1, max_value=5, value=3, step=1
        )
    with right_col:
        num_cols = st.slider(
            "Number of columns", min_value=1, max_value=5, value=3, step=1
        )

    if st.button("질문 (Query)"):
        if not (url_input or uploaded_file):
            st.error("Please upload an image or paste an image URL.")
        else:
            st.markdown("""---""")
            with st.spinner("Computing..."):
                # Prefer the uploaded file; otherwise stream the URL.
                if uploaded_file is not None:
                    source = uploaded_file
                else:
                    source = requests.get(url_input, stream=True).raw
                pil_image = Image.open(source)
                st.image(pil_image)

                tiles = split_image(pil_image, num_rows, num_cols)

                model_inputs = processor(
                    text=prompt, images=tiles, return_tensors="jax", padding=True
                )
                # Processor emits NCHW; the Flax model expects NHWC.
                model_inputs["pixel_values"] = jnp.transpose(
                    model_inputs["pixel_values"], axes=[0, 2, 3, 1]
                )
                model_outputs = model(**model_inputs)
                # Softmax across tiles (axis 0) turns per-tile logits into a
                # relevance distribution over the grid for the single prompt.
                tile_probs = jax.nn.softmax(model_outputs.logits_per_image, axis=0)
                ranked = sorted(
                    enumerate(tile_probs), key=lambda pair: pair[1], reverse=True
                )
                for tile_idx, tile_prob in ranked:
                    st.text(f"Score: {tile_prob[0]:.3f}")
                    st.image(tiles[tile_idx])