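"""Streamlit page for patch-based relevance ranking with KoCLIP.

The input image is cut into a grid of tiles, each tile is scored against a
Korean text prompt with a CLIP-style model, and the tiles are shown in
descending order of relevance.
"""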
import os
import jax
import jax.numpy as jnp
import numpy as np
import requests
import streamlit as st
from PIL import Image
from utils import load_model


def split_image(im, num_rows=3, num_cols=3):
    """Split a PIL image into a row-major list of num_rows x num_cols numpy tiles."""
    im = np.array(im)
    row_size = im.shape[0] // num_rows
    col_size = im.shape[1] // num_cols
    tiles = [
        im[row : row + row_size, col : col + col_size]
        for row in range(0, num_rows * row_size, row_size)
        for col in range(0, num_cols * col_size, col_size)
    ]
    return tiles
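
# Illustrative example (not executed by the app): for a 900x600 (height x width)
# photo, split_image(img, num_rows=3, num_cols=3) returns nine 300x200 tiles in
# row-major order (top-left tile first). Any remainder pixels on the right and
# bottom edges are dropped by the integer division above.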


def app(model_name):
    model, processor = load_model(f"koclip/{model_name}")

    st.title("Patch-based Relevance Ranking")
    st.markdown(
        """
        Given a piece of text, the CLIP model finds the part of an image that best matches the text.
        To try it out, you can

        1. Upload an image
        2. Describe a part of the image in text

        The app will then return the tile of the image grid that is most relevant to your prompt.
        You can control how fine-grained the search is by choosing the number of rows and columns
        that make up the grid.

        ---
        """
    )

    query1 = st.text_input(
        "Enter a URL to an image...",
        value="https://img.sbs.co.kr/newimg/news/20200823/201463830_1280.jpg",
    )
    query2 = st.file_uploader("or upload an image...", type=["jpg", "jpeg", "png"])
    captions = st.text_input(
        "Enter a prompt to query the image.",
        # Korean: "This is a picture of Gyeongbokgung Palace in Seoul."
        value="이건 서울의 경복궁 사진이다.",
    )

    # st.beta_columns was removed in newer Streamlit releases; st.columns is the
    # current equivalent.
    col1, col2 = st.columns(2)
    with col1:
        num_rows = st.slider(
            "Number of rows", min_value=1, max_value=5, value=3, step=1
        )
    with col2:
        num_cols = st.slider(
            "Number of columns", min_value=1, max_value=5, value=3, step=1
        )

    if st.button("질문 (Query)"):
        if not any([query1, query2]):
            st.error("Please upload an image or paste an image URL.")
        else:
            st.markdown("""---""")
            with st.spinner("Computing..."):
                # Prefer the uploaded file; otherwise stream the image from the URL.
                image_data = (
                    query2
                    if query2 is not None
                    else requests.get(query1, stream=True).raw
                )
                image = Image.open(image_data)
                st.image(image)

                # Score every tile of the grid against the single text prompt.
                images = split_image(image, num_rows, num_cols)
                inputs = processor(
                    text=captions, images=images, return_tensors="jax", padding=True
                )
                # The processor returns NCHW pixel values; the Flax vision model
                # expects NHWC, hence the transpose.
                inputs["pixel_values"] = jnp.transpose(
                    inputs["pixel_values"], axes=[0, 2, 3, 1]
                )
                outputs = model(**inputs)
                # logits_per_image has shape (num_tiles, 1); softmax over axis 0
                # turns the scores into a distribution across the tiles.
                probs = jax.nn.softmax(outputs.logits_per_image, axis=0)

                # Display the tiles from most to least relevant.
                for idx, prob in sorted(
                    enumerate(probs), key=lambda x: x[1], reverse=True
                ):
                    st.text(f"Score: {prob[0]:.3f}")
                    st.image(images[idx])
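

# Minimal sketch of how this page could be wired up from a top-level Streamlit
# script; the module name and the model list below are assumptions for
# illustration, not part of this file:
#
#     import streamlit as st
#     import image2text  # hypothetical name for this module
#
#     model_name = st.sidebar.selectbox("Model", ["koclip-base"])  # assumed model name
#     image2text.app(model_name)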