tree-park committed on
Commit
84c806e
1 Parent(s): b21e1da

Add most relevant part

Browse files
Files changed (3) hide show
  1. app.py +2 -1
  2. most_relevant_part.py +80 -0
  3. text2image.py +1 -1
app.py CHANGED
@@ -2,8 +2,9 @@ import streamlit as st
2
 
3
  import image2text
4
  import text2image
 
5
 
6
- PAGES = {"Text to Image": text2image, "Image to Text": image2text}
7
 
8
  st.sidebar.title("Navigation")
9
  model = st.sidebar.selectbox("Choose a model", ["koclip-base", "koclip-large"])
 
2
 
3
  import image2text
4
  import text2image
5
+ import most_relevant_part
6
 
7
+ PAGES = {"Text to Image": text2image, "Image to Text": image2text, "Most Relevant Part of Image": most_relevant_part}
8
 
9
  st.sidebar.title("Navigation")
10
  model = st.sidebar.selectbox("Choose a model", ["koclip-base", "koclip-large"])
most_relevant_part.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import requests
3
+ import streamlit as st
4
+ from PIL import Image
5
+
6
+ import jax
7
+ import jax.numpy as jnp
8
+ import numpy as np
9
+
10
+ from utils import load_model
11
+
12
def split_image(im, rows=3, cols=3):
    """Split an image into a ``rows`` x ``cols`` grid of tiles.

    Args:
        im: Anything ``np.array`` can convert to an array with at least two
            dimensions (the first two are treated as height and width).
        rows: Number of tiles along the first axis (default 3).
        cols: Number of tiles along the second axis (default 3).

    Returns:
        A list of exactly ``rows * cols`` array slices in row-major order
        (left to right, then top to bottom). Together the tiles cover the
        whole image; when a dimension is not evenly divisible the later
        tiles absorb the remainder.

    Raises:
        ValueError: If the grid size is not positive or the image is smaller
            than the requested grid in either dimension.
    """
    if rows < 1 or cols < 1:
        raise ValueError("rows and cols must be positive integers")

    pixels = np.array(im)
    height, width = pixels.shape[:2]
    if height < rows or width < cols:
        raise ValueError(
            f"cannot split a {height}x{width} image into a {rows}x{cols} grid"
        )

    # Explicit edges keep the tile count fixed. Stepping by ``height // rows``
    # would produce an extra sliver tile whenever the size is not divisible.
    row_edges = [height * i // rows for i in range(rows + 1)]
    col_edges = [width * j // cols for j in range(cols + 1)]

    return [
        pixels[top:bottom, left:right]
        for top, bottom in zip(row_edges, row_edges[1:])
        for left, right in zip(col_edges, col_edges[1:])
    ]
22
+
23
+
24
+ # def split_image(X):
25
+ # num_rows = X.shape[0] // 224
26
+ # num_cols = X.shape[1] // 224
27
+ # Xc = X[0:num_rows * 224, 0:num_cols * 224, :]
28
+ # patches = []
29
+ # for j in range(num_rows):
30
+ # for i in range(num_cols):
31
+ # patches.append(Xc[j * 224:(j + 1) * 224, i * 224:(i + 1) * 224, :])
32
+ # return patches
33
+
34
+
35
def _load_image(uploaded_file, url):
    """Open the query image, preferring the uploaded file over the URL.

    Args:
        uploaded_file: File-like object from the uploader, or ``None``.
        url: Image URL used when no file was uploaded.

    Returns:
        The opened ``PIL.Image.Image``.

    Raises:
        requests.RequestException: If the download fails, times out or
            returns an HTTP error status.
        OSError: If the data cannot be opened as an image.
    """
    from io import BytesIO  # local import: only this helper needs it

    if uploaded_file is not None:
        return Image.open(uploaded_file)

    # Fetch the full body with a timeout and a status check. Reading
    # ``.raw`` would skip content-encoding decoding and never time out.
    response = requests.get(url, timeout=10)
    response.raise_for_status()
    return Image.open(BytesIO(response.content))


def app(model_name):
    """Render the "Most Relevant Part of Image" Streamlit page.

    Loads the model and processor for ``koclip/{model_name}``, asks for an
    image (URL or upload) and a text query, splits the image into tiles with
    ``split_image`` and shows every tile with its score, highest first.

    Args:
        model_name: Model identifier appended to the ``koclip/`` prefix.
    """
    model, processor = load_model(f"koclip/{model_name}")

    st.title("Most Relevant Part of Image")
    st.markdown("""
        Given a piece of text, the CLIP model finds the part of an image that best explains the text.
        To try it out, you can
        1) Upload an image
        2) Explain a part of the image in text
        Which will yield the most relevant image tile from a 3x3 grid of the image
    """)

    query1 = st.text_input(
        "Enter a URL to an image...",
        value="https://img.sbs.co.kr/newimg/news/20200823/201463830_1280.jpg",
    )
    query2 = st.file_uploader("or upload an image...",
                              type=["jpg", "jpeg", "png"])
    captions = st.text_input(
        "Enter query to find most relevant part of image ",
        value="이건 서울의 경복궁 사진이다.",
    )

    if not st.button("질문 (Query)"):
        return

    if not query1 and query2 is None:
        st.error("Please upload an image or paste an image URL.")
        return

    try:
        image = _load_image(query2, query1)
    except (requests.RequestException, OSError) as err:
        # Report bad URLs, network failures and non-image data in the page
        # instead of letting the exception propagate.
        st.error(f"Could not load the image: {err}")
        return
    st.image(image)

    # Tile an RGB copy so every tile has three channels.
    # NOTE(review): assumes the processor expects 3-channel input; confirm.
    images = split_image(image.convert("RGB"))

    inputs = processor(text=captions,
                       images=images,
                       return_tensors="jax",
                       padding=True)
    # Move axis 1 of pixel_values to the last position.
    # NOTE(review): presumably channels-first -> channels-last for the
    # model; confirm against load_model.
    inputs["pixel_values"] = jnp.transpose(inputs["pixel_values"],
                                           axes=[0, 2, 3, 1])
    outputs = model(**inputs)

    # Softmax over axis 0 of logits_per_image (presumably the tile axis), so
    # the scores are relative between tiles.
    probs = jax.nn.softmax(outputs.logits_per_image, axis=0)
    ranked = sorted(enumerate(probs),
                    key=lambda pair: float(pair[1][0]),
                    reverse=True)
    for idx, prob in ranked:
        st.text(f"Score: {prob[0]:.3f}")
        st.image(images[idx])
text2image.py CHANGED
@@ -40,7 +40,7 @@ def app(model_name):
40
  result_imgs, result_captions = [], []
41
  for file, dist in zip(result_files, dists):
42
  result_imgs.append(plt.imread(os.path.join(images_directory, file)))
43
- result_captions.append("{:s} (유사도: {:.3f})".format(file, 1.0 - dist))
44
 
45
  st.image(result_imgs[:3], caption=result_captions[:3], width=200)
46
  st.image(result_imgs[3:6], caption=result_captions[3:6], width=200)
 
40
  result_imgs, result_captions = [], []
41
  for file, dist in zip(result_files, dists):
42
  result_imgs.append(plt.imread(os.path.join(images_directory, file)))
43
+ result_captions.append("Score: {:.3f}".format(1.0 - dist))
44
 
45
  st.image(result_imgs[:3], caption=result_captions[:3], width=200)
46
  st.image(result_imgs[3:6], caption=result_captions[3:6], width=200)