Spaces:
Sleeping
Sleeping
File size: 5,548 Bytes
2c3862a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 |
import streamlit as st
from transformers import AutoModel
from PIL import Image
import torch
import numpy as np
@st.cache_resource
def load_model():
    """Load the Magi model once and place it on the best available device.

    ``st.cache_resource`` keeps the returned object alive across Streamlit
    reruns, so the weights are loaded only once rather than on every
    interaction.

    Returns:
        The model from ``AutoModel.from_pretrained("ragavsachdeva/magi")``,
        moved to ``"cuda"`` when ``torch.cuda.is_available()`` is true and to
        ``"cpu"`` otherwise.
    """
    # trust_remote_code=True lets transformers execute the model class code
    # shipped in the hub repository (the model is not a built-in architecture).
    magi = AutoModel.from_pretrained("ragavsachdeva/magi", trust_remote_code=True)
    target_device = "cpu"
    if torch.cuda.is_available():
        target_device = "cuda"
    magi.to(target_device)
    return magi
@st.cache_data
def read_image_as_np_array(image_path):
    """Decode an image into a 3-channel greyscale NumPy array.

    The image is converted to single-channel greyscale ("L") and then back to
    "RGB", so all three channels hold identical values. Presumably this is so
    the model sees colour-free pages in the 3-channel layout it expects;
    confirm against the model's preprocessing.

    Args:
        image_path: Either a filesystem path (``str`` / ``os.PathLike``) or a
            binary file-like object with ``read()``, such as the
            ``UploadedFile`` returned by ``st.file_uploader``.

    Returns:
        ``np.array`` of the RGB PIL image: an array of shape (height, width, 3)
        with dtype uint8.
    """
    if hasattr(image_path, "read"):
        # File-like input (e.g. a Streamlit UploadedFile). open() cannot take
        # it, so hand it to PIL directly. Rewind first in case an earlier
        # reader left the stream positioned at the end.
        if hasattr(image_path, "seek"):
            image_path.seek(0)
        image = Image.open(image_path).convert("L").convert("RGB")
    else:
        with open(image_path, "rb") as file:
            # .convert() forces PIL to decode now, while the file is still open.
            image = Image.open(file).convert("L").convert("RGB")
    return np.array(image)
@st.cache_data
def predict_detections_and_associations(
    image_path,
    character_detection_threshold,
    panel_detection_threshold,
    text_detection_threshold,
    character_character_matching_threshold,
    text_character_matching_threshold,
):
    """Detect panels/characters/texts in one image and associate them.

    Results are memoised by ``st.cache_data`` on the image and the five
    thresholds. Moving a slider therefore recomputes, while plain reruns do
    not. The module-level ``model`` is used for inference.

    Args:
        image_path: Path or file-like image accepted by
            ``read_image_as_np_array``.
        character_detection_threshold: Forwarded unchanged to the model.
        panel_detection_threshold: Forwarded unchanged to the model.
        text_detection_threshold: Forwarded unchanged to the model.
        character_character_matching_threshold: Forwarded unchanged to the model.
        text_character_matching_threshold: Forwarded unchanged to the model.

    Returns:
        The single-image entry (index 0) of
        ``model.predict_detections_and_associations``. Callers read its
        ``"texts"`` key.
    """
    page = read_image_as_np_array(image_path)
    thresholds = dict(
        character_detection_threshold=character_detection_threshold,
        panel_detection_threshold=panel_detection_threshold,
        text_detection_threshold=text_detection_threshold,
        character_character_matching_threshold=character_character_matching_threshold,
        text_character_matching_threshold=text_character_matching_threshold,
    )
    # Inference only: disable autograd bookkeeping.
    with torch.no_grad():
        per_image = model.predict_detections_and_associations([page], **thresholds)
    return per_image[0]
@st.cache_data
def predict_ocr(
    image_path,
    character_detection_threshold,
    panel_detection_threshold,
    text_detection_threshold,
    character_character_matching_threshold,
    text_character_matching_threshold,
):
    """Run OCR on the text boxes detected in a single image.

    The text boxes come from ``predict_detections_and_associations`` called
    with exactly the same arguments. That call is a cache hit when the
    detections were already computed, and it keeps the OCR aligned with the
    detections shown to the user.

    Because ``st.cache_data`` keys the cache only on the arguments, this
    function depends on nothing but its arguments (and the loaded ``model``).
    It must not consult UI globals such as sidebar toggles, or a stale result
    could be served later.

    Args:
        image_path: Path or file-like image accepted by
            ``read_image_as_np_array``.
        character_detection_threshold: Forwarded to detection.
        panel_detection_threshold: Forwarded to detection.
        text_detection_threshold: Forwarded to detection.
        character_character_matching_threshold: Forwarded to detection.
        text_character_matching_threshold: Forwarded to detection.

    Returns:
        The value of ``model.predict_ocr`` for a one-image batch. Callers take
        ``[0]`` to get this image's OCR results.
    """
    image = read_image_as_np_array(image_path)
    # Use this function's own argument, not the module-level path_to_image,
    # so that every input is part of the cache key.
    result = predict_detections_and_associations(
        image_path,
        character_detection_threshold,
        panel_detection_threshold,
        text_detection_threshold,
        character_character_matching_threshold,
        text_character_matching_threshold,
    )
    text_bboxes_for_all_images = [result["texts"]]
    with torch.no_grad():
        ocr_results = model.predict_ocr([image], text_bboxes_for_all_images)
    return ocr_results
# Load (or fetch from Streamlit's resource cache) the model. The prediction
# helpers read it as a module-level global.
model = load_model()

st.markdown("<style>.title{font-size:2em;text-align:center;color:#fff;font-family:'Comic Sans MS',cursive;text-transform:uppercase;letter-spacing:.1em;padding:.5em 0 .2em;background:0 0}.title span{background:-webkit-linear-gradient(45deg,#6495ed,#4169e1);-webkit-background-clip:text;-webkit-text-fill-color:transparent}.subheading{font-size:1.5em;text-align:center;color:#ddd;font-family:'Comic Sans MS',cursive}.affil,.authors{font-size:1em;text-align:center;color:#ddd;font-family:'Comic Sans MS',cursive}.authors{padding-top:1em}</style><div class='title-container'> <div class='title'> The <span>Ma</span>n<span>g</span>a Wh<span>i</span>sperer </div> <div class='subheading'> Automatically Generating Transcriptions for Comics </div> <div class='authors'> Ragav Sachdeva and Andrew Zisserman </div> <div class='affil'> University of Oxford </div></div>", unsafe_allow_html=True)

# Despite its name, this holds the uploaded file object (or None until the
# user uploads something), not a filesystem path. It is passed as-is to the
# cached image and prediction helpers.
path_to_image = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg"])

# Sidebar: which outputs to produce.
st.sidebar.markdown("**Mode**")
generate_detections_and_associations = st.sidebar.toggle("Generate detections and associations", True)
generate_transcript = st.sidebar.toggle("Generate transcript (slower)", False)

# Sidebar: thresholds forwarded unchanged to the model's prediction calls.
st.sidebar.markdown("**Hyperparameters**")
input_character_detection_threshold = st.sidebar.slider('Character detection threshold', 0.0, 1.0, 0.30, step=0.01)
input_panel_detection_threshold = st.sidebar.slider('Panel detection threshold', 0.0, 1.0, 0.2, step=0.01)
input_text_detection_threshold = st.sidebar.slider('Text detection threshold', 0.0, 1.0, 0.25, step=0.01)
input_character_character_matching_threshold = st.sidebar.slider('Character-character matching threshold', 0.0, 1.0, 0.7, step=0.01)
input_text_character_matching_threshold = st.sidebar.slider('Text-character matching threshold', 0.0, 1.0, 0.4, step=0.01)
# Nothing to render until an image has been uploaded; st.stop() ends this run.
if path_to_image is None:
    st.stop()

image = read_image_as_np_array(path_to_image)
st.markdown("**Prediction**")

# The transcript is built from the detections/associations as well, so they
# are needed in either mode.
if generate_detections_and_associations or generate_transcript:
    result = predict_detections_and_associations(
        path_to_image,
        input_character_detection_threshold,
        input_panel_detection_threshold,
        input_text_detection_threshold,
        input_character_character_matching_threshold,
        input_text_character_matching_threshold,
    )

if generate_transcript:
    ocr_results = predict_ocr(
        path_to_image,
        input_character_detection_threshold,
        input_panel_detection_threshold,
        input_text_detection_threshold,
        input_character_character_matching_threshold,
        input_text_character_matching_threshold,
    )

if generate_detections_and_associations and generate_transcript:
    # Side by side: annotated image on the left, transcript on the right.
    col1, col2 = st.columns(2)
    output = model.visualise_single_image_prediction(image, result)
    col1.image(output)
    # Reuse ocr_results from predict_ocr above (cached, and computed under
    # torch.no_grad()) instead of running OCR a second time.
    transcript = model.generate_transcript_for_single_image(result, ocr_results[0])
    col2.text(transcript)
elif generate_detections_and_associations:
    output = model.visualise_single_image_prediction(image, result)
    st.image(output)
elif generate_transcript:
    transcript = model.generate_transcript_for_single_image(result, ocr_results[0])
    st.text(transcript)
|