GenAI-Arena / serve /vote_utils.py
tianleliphoebe's picture
fix bug
8a702cc
raw
history blame
42.7 kB
import datetime
import json
import os
import time
import uuid
from pathlib import Path

import gradio as gr
import imageio
import regex as re

from .utils import *
from .log_utils import build_logger
from .constants import IMAGE_DIR, VIDEO_DIR
# Module-level loggers, one per task/mode pair.
#   ig / ie / vg    -> image generation / image editing / video generation,
#                      single-model direct chat
#   igm / iem / vgm -> the "multi" variants, side-by-side and battle
ig_logger = build_logger("gradio_web_server_image_generation", "gr_web_image_generation.log") # ig = image generation (single-model direct chat)
igm_logger = build_logger("gradio_web_server_image_generation_multi", "gr_web_image_generation_multi.log") # igm = image generation multi (side-by-side and battle)
ie_logger = build_logger("gradio_web_server_image_editing", "gr_web_image_editing.log") # ie = image editing (single-model direct chat)
iem_logger = build_logger("gradio_web_server_image_editing_multi", "gr_web_image_editing_multi.log") # iem = image editing multi (side-by-side and battle)
vg_logger = build_logger("gradio_web_server_video_generation", "gr_web_video_generation.log") # vg = video generation (single-model direct chat)
vgm_logger = build_logger("gradio_web_server_video_generation_multi", "gr_web_video_generation_multi.log") # vgm = video generation multi (side-by-side and battle)
def vote_last_response_ig(state, vote_type, model_selector, request: gr.Request):
    """Record a vote on a single-model image generation and archive the image.

    Appends one JSON line describing the vote to the conversation log file,
    hands the same record to ``append_json_item_on_log_server``, then writes
    ``state.output`` as a JPEG under ``{IMAGE_DIR}/generation/`` and passes
    the path to ``save_image_file_on_log_server``.

    Args:
        state: Generation state; must provide ``dict()``, ``conv_id`` and
            ``output`` (an object with a ``save(fp, format)`` method,
            presumably a PIL image).
        vote_type: Value stored under ``"type"`` in the log record.
        model_selector: Value stored under ``"model"`` in the log record.
        request: Gradio request, used to resolve the client IP.
    """
    with open(get_conv_log_filename(), "a") as fout:
        data = {
            "tstamp": round(time.time(), 4),
            "type": vote_type,
            "model": model_selector,
            "state": state.dict(),
            "ip": get_ip(request),
        }
        fout.write(json.dumps(data) + "\n")
        append_json_item_on_log_server(data, get_conv_log_filename())

    output_file = f'{IMAGE_DIR}/generation/{state.conv_id}.jpg'
    # Make sure the target directory exists, as the video vote path does.
    os.makedirs(os.path.dirname(output_file), exist_ok=True)
    # JPEG data is binary, so the file object must be opened in binary mode.
    with open(output_file, 'wb') as f:
        state.output.save(f, 'JPEG')
    save_image_file_on_log_server(output_file)
def vote_last_response_igm(states, vote_type, model_selectors, request: gr.Request):
    """Record a vote on a multi-model image generation and archive each image.

    Appends one JSON line covering all states to the conversation log file,
    hands the same record to ``append_json_item_on_log_server``, then writes
    every state's ``output`` as a JPEG under ``{IMAGE_DIR}/generation/`` and
    passes each path to ``save_image_file_on_log_server``.

    Args:
        states: Iterable of generation states; each must provide ``dict()``,
            ``conv_id`` and ``output`` (an object with ``save(fp, format)``,
            presumably a PIL image).
        vote_type: Value stored under ``"type"`` in the log record.
        model_selectors: Iterable of selector values stored under ``"models"``.
        request: Gradio request, used to resolve the client IP.
    """
    with open(get_conv_log_filename(), "a") as fout:
        data = {
            "tstamp": round(time.time(), 4),
            "type": vote_type,
            "models": list(model_selectors),
            "states": [x.dict() for x in states],
            "ip": get_ip(request),
        }
        fout.write(json.dumps(data) + "\n")
        append_json_item_on_log_server(data, get_conv_log_filename())

    for state in states:
        output_file = f'{IMAGE_DIR}/generation/{state.conv_id}.jpg'
        # Make sure the target directory exists, as the video vote path does.
        os.makedirs(os.path.dirname(output_file), exist_ok=True)
        # JPEG data is binary, so the file object must be opened in binary mode.
        with open(output_file, 'wb') as f:
            state.output.save(f, 'JPEG')
        save_image_file_on_log_server(output_file)
def vote_last_response_ie(state, vote_type, model_selector, request: gr.Request):
    """Record a vote on a single-model image edit and archive both images.

    Appends one JSON line describing the vote to the conversation log file,
    hands the same record to ``append_json_item_on_log_server``, then writes
    the edited image to ``{conv_id}.jpg`` and the source image to
    ``{conv_id}_source.jpg`` under ``{IMAGE_DIR}/edition/`` and passes both
    paths to ``save_image_file_on_log_server``.

    Args:
        state: Editing state; must provide ``dict()``, ``conv_id``,
            ``output`` and ``source_image`` (objects with
            ``save(fp, format)``, presumably PIL images).
        vote_type: Value stored under ``"type"`` in the log record.
        model_selector: Value stored under ``"model"`` in the log record.
        request: Gradio request, used to resolve the client IP.
    """
    with open(get_conv_log_filename(), "a") as fout:
        data = {
            "tstamp": round(time.time(), 4),
            "type": vote_type,
            "model": model_selector,
            "state": state.dict(),
            "ip": get_ip(request),
        }
        fout.write(json.dumps(data) + "\n")
        append_json_item_on_log_server(data, get_conv_log_filename())

    output_file = f'{IMAGE_DIR}/edition/{state.conv_id}.jpg'
    source_file = f'{IMAGE_DIR}/edition/{state.conv_id}_source.jpg'
    # Both files share one directory; make sure it exists before writing.
    os.makedirs(os.path.dirname(output_file), exist_ok=True)
    # JPEG data is binary, so the file objects must be opened in binary mode.
    with open(output_file, 'wb') as f:
        state.output.save(f, 'JPEG')
    with open(source_file, 'wb') as sf:
        state.source_image.save(sf, 'JPEG')
    save_image_file_on_log_server(output_file)
    save_image_file_on_log_server(source_file)
def vote_last_response_iem(states, vote_type, model_selectors, request: gr.Request):
    """Record a vote on a multi-model image edit and archive all images.

    Appends one JSON line covering all states to the conversation log file,
    hands the same record to ``append_json_item_on_log_server``, then for each
    state writes the edited image to ``{conv_id}.jpg`` and the source image to
    ``{conv_id}_source.jpg`` under ``{IMAGE_DIR}/edition/`` and passes both
    paths to ``save_image_file_on_log_server``.

    Args:
        states: Iterable of editing states; each must provide ``dict()``,
            ``conv_id``, ``output`` and ``source_image`` (objects with
            ``save(fp, format)``, presumably PIL images).
        vote_type: Value stored under ``"type"`` in the log record.
        model_selectors: Iterable of selector values stored under ``"models"``.
        request: Gradio request, used to resolve the client IP.
    """
    with open(get_conv_log_filename(), "a") as fout:
        data = {
            "tstamp": round(time.time(), 4),
            "type": vote_type,
            "models": list(model_selectors),
            "states": [x.dict() for x in states],
            "ip": get_ip(request),
        }
        fout.write(json.dumps(data) + "\n")
        append_json_item_on_log_server(data, get_conv_log_filename())

    for state in states:
        output_file = f'{IMAGE_DIR}/edition/{state.conv_id}.jpg'
        source_file = f'{IMAGE_DIR}/edition/{state.conv_id}_source.jpg'
        # Both files share one directory; make sure it exists before writing.
        os.makedirs(os.path.dirname(output_file), exist_ok=True)
        # JPEG data is binary, so the file objects must be opened in binary mode.
        with open(output_file, 'wb') as f:
            state.output.save(f, 'JPEG')
        with open(source_file, 'wb') as sf:
            state.source_image.save(sf, 'JPEG')
        save_image_file_on_log_server(output_file)
        save_image_file_on_log_server(source_file)
def vote_last_response_vg(state, vote_type, model_selector, request: gr.Request):
    """Record a vote on a single-model video generation and archive the video.

    Appends one JSON line describing the vote to the conversation log file,
    hands the same record to ``append_json_item_on_log_server``, then stores
    the video at ``{VIDEO_DIR}/generation/{conv_id}.mp4`` and passes the path
    to ``save_video_file_on_log_server``.

    For models whose name starts with ``fal`` the output is fetched with
    ``requests.get`` (so it is presumably a URL); otherwise it is encoded
    with ``imageio.mimwrite``.
    """
    with open(get_conv_log_filename(), "a") as log_file:
        record = {
            "tstamp": round(time.time(), 4),
            "type": vote_type,
            "model": model_selector,
            "state": state.dict(),
            "ip": get_ip(request),
        }
        log_file.write(json.dumps(record) + "\n")
        append_json_item_on_log_server(record, get_conv_log_filename())

    video_path = f'{VIDEO_DIR}/generation/{state.conv_id}.mp4'
    os.makedirs(os.path.dirname(video_path), exist_ok=True)
    if not state.model_name.startswith('fal'):
        # Locally produced frames: encode them into the mp4 file.
        print("======== video shape: ========")
        print(state.output.shape)
        imageio.mimwrite(video_path, state.output, fps=8, quality=9)
    else:
        # Remote result: download the bytes and store them as-is.
        response = requests.get(state.output)
        with open(video_path, 'wb') as video_file:
            video_file.write(response.content)
    save_video_file_on_log_server(video_path)
def vote_last_response_vgm(states, vote_type, model_selectors, request: gr.Request):
    """Record a vote on a multi-model video generation and archive each video.

    Appends one JSON line covering all states to the conversation log file,
    hands the same record to ``append_json_item_on_log_server``, then stores
    every state's video at ``{VIDEO_DIR}/generation/{conv_id}.mp4`` and passes
    each path to ``save_video_file_on_log_server``.

    For models whose name starts with ``fal`` the output is fetched with
    ``requests.get`` (so it is presumably a URL); otherwise it is encoded
    with ``imageio.mimwrite``.
    """
    with open(get_conv_log_filename(), "a") as log_file:
        record = {
            "tstamp": round(time.time(), 4),
            "type": vote_type,
            "models": [x for x in model_selectors],
            "states": [x.dict() for x in states],
            "ip": get_ip(request),
        }
        log_file.write(json.dumps(record) + "\n")
        append_json_item_on_log_server(record, get_conv_log_filename())

    for current in states:
        video_path = f'{VIDEO_DIR}/generation/{current.conv_id}.mp4'
        os.makedirs(os.path.dirname(video_path), exist_ok=True)
        if not current.model_name.startswith('fal'):
            # Locally produced frames: encode them into the mp4 file.
            print("======== video shape: ========")
            print(current.output.shape)
            imageio.mimwrite(video_path, current.output, fps=8, quality=9)
        else:
            # Remote result: download the bytes and store them as-is.
            response = requests.get(current.output)
            with open(video_path, 'wb') as video_file:
                video_file.write(response.content)
        save_video_file_on_log_server(video_path)
## Image Generation (IG) Single Model Direct Chat
def upvote_last_response_ig(state, model_selector, request: gr.Request):
    """Log and record an "upvote" for a single-model image generation.

    Returns:
        A 4-tuple: an empty string followed by three ``disable_btn`` values.
    """
    ig_logger.info(f"upvote. ip: {get_ip(request)}")
    vote_last_response_ig(state, "upvote", model_selector, request)
    return ("", disable_btn, disable_btn, disable_btn)
def downvote_last_response_ig(state, model_selector, request: gr.Request):
    """Log and record a "downvote" for a single-model image generation.

    Returns:
        A 4-tuple: an empty string followed by three ``disable_btn`` values.
    """
    ig_logger.info(f"downvote. ip: {get_ip(request)}")
    vote_last_response_ig(state, "downvote", model_selector, request)
    return ("", disable_btn, disable_btn, disable_btn)
def flag_last_response_ig(state, model_selector, request: gr.Request):
    """Log and record a "flag" for a single-model image generation.

    Returns:
        A 4-tuple: an empty string followed by three ``disable_btn`` values.
    """
    ig_logger.info(f"flag. ip: {get_ip(request)}")
    vote_last_response_ig(state, "flag", model_selector, request)
    return ("", disable_btn, disable_btn, disable_btn)
## Image Generation Multi (IGM) Side-by-Side and Battle
def leftvote_last_response_igm(
    state0, state1, model_selector0, model_selector1, request: gr.Request
):
    """Log and record a "leftvote" for a pair of image generations.

    Returns:
        A 7-tuple: an empty string, four ``disable_btn`` values and two
        visible ``gr.Markdown`` labels. When ``model_selector0`` is empty the
        labels show the second ``_``-separated field of each state's
        ``model_name``; otherwise they show the full ``model_name``.
    """
    igm_logger.info(f"leftvote (named). ip: {get_ip(request)}")
    vote_last_response_igm(
        [state0, state1], "leftvote", [model_selector0, model_selector1], request
    )
    if model_selector0 != "":
        labels = (
            gr.Markdown(state0.model_name, visible=True),
            gr.Markdown(state1.model_name, visible=True),
        )
    else:
        labels = (
            gr.Markdown(f"### Model A: {state0.model_name.split('_')[1]}", visible=True),
            gr.Markdown(f"### Model B: {state1.model_name.split('_')[1]}", visible=True),
        )
    return ("",) + (disable_btn,) * 4 + labels
def rightvote_last_response_igm(
    state0, state1, model_selector0, model_selector1, request: gr.Request
):
    """Log and record a "rightvote" for a pair of image generations.

    Returns:
        A 7-tuple: an empty string, four ``disable_btn`` values and two
        visible ``gr.Markdown`` labels. When ``model_selector0`` is empty the
        labels show the second ``_``-separated field of each state's
        ``model_name``; otherwise they show the full ``model_name``.
    """
    igm_logger.info(f"rightvote (named). ip: {get_ip(request)}")
    vote_last_response_igm(
        [state0, state1], "rightvote", [model_selector0, model_selector1], request
    )
    # The leftover debug print of model_selector0 was dropped so this handler
    # matches its left/tie/bothbad siblings.
    if model_selector0 == "":
        labels = (
            gr.Markdown(f"### Model A: {state0.model_name.split('_')[1]}", visible=True),
            gr.Markdown(f"### Model B: {state1.model_name.split('_')[1]}", visible=True),
        )
    else:
        labels = (
            gr.Markdown(state0.model_name, visible=True),
            gr.Markdown(state1.model_name, visible=True),
        )
    return ("",) + (disable_btn,) * 4 + labels
def tievote_last_response_igm(
    state0, state1, model_selector0, model_selector1, request: gr.Request
):
    """Log and record a "tievote" for a pair of image generations.

    Returns:
        A 7-tuple: an empty string, four ``disable_btn`` values and two
        visible ``gr.Markdown`` labels. When ``model_selector0`` is empty the
        labels show the second ``_``-separated field of each state's
        ``model_name``; otherwise they show the full ``model_name``.
    """
    igm_logger.info(f"tievote (named). ip: {get_ip(request)}")
    vote_last_response_igm(
        [state0, state1], "tievote", [model_selector0, model_selector1], request
    )
    if model_selector0 != "":
        labels = (
            gr.Markdown(state0.model_name, visible=True),
            gr.Markdown(state1.model_name, visible=True),
        )
    else:
        labels = (
            gr.Markdown(f"### Model A: {state0.model_name.split('_')[1]}", visible=True),
            gr.Markdown(f"### Model B: {state1.model_name.split('_')[1]}", visible=True),
        )
    return ("",) + (disable_btn,) * 4 + labels
def bothbad_vote_last_response_igm(
    state0, state1, model_selector0, model_selector1, request: gr.Request
):
    """Log and record a "bothbad_vote" for a pair of image generations.

    Returns:
        A 7-tuple: an empty string, four ``disable_btn`` values and two
        visible ``gr.Markdown`` labels. When ``model_selector0`` is empty the
        labels show the second ``_``-separated field of each state's
        ``model_name``; otherwise they show the full ``model_name``.
    """
    igm_logger.info(f"bothbad_vote (named). ip: {get_ip(request)}")
    vote_last_response_igm(
        [state0, state1], "bothbad_vote", [model_selector0, model_selector1], request
    )
    if model_selector0 != "":
        labels = (
            gr.Markdown(state0.model_name, visible=True),
            gr.Markdown(state1.model_name, visible=True),
        )
    else:
        labels = (
            gr.Markdown(f"### Model A: {state0.model_name.split('_')[1]}", visible=True),
            gr.Markdown(f"### Model B: {state1.model_name.split('_')[1]}", visible=True),
        )
    return ("",) + (disable_btn,) * 4 + labels
## Image Editing (IE) Single Model Direct Chat
def upvote_last_response_ie(state, model_selector, request: gr.Request):
    """Log and record an "upvote" for a single-model image edit.

    Returns:
        A 7-tuple: two empty strings, a fresh 512x512 PIL-typed ``gr.Image``,
        another empty string, and three ``disable_btn`` values.
    """
    ie_logger.info(f"upvote. ip: {get_ip(request)}")
    vote_last_response_ie(state, "upvote", model_selector, request)
    blank_image = gr.Image(height=512, width=512, type="pil")
    return ("", "", blank_image, "", disable_btn, disable_btn, disable_btn)
def downvote_last_response_ie(state, model_selector, request: gr.Request):
    """Log and record a "downvote" for a single-model image edit.

    Returns:
        A 7-tuple: two empty strings, a fresh 512x512 PIL-typed ``gr.Image``,
        another empty string, and three ``disable_btn`` values.
    """
    ie_logger.info(f"downvote. ip: {get_ip(request)}")
    vote_last_response_ie(state, "downvote", model_selector, request)
    blank_image = gr.Image(height=512, width=512, type="pil")
    return ("", "", blank_image, "", disable_btn, disable_btn, disable_btn)
def flag_last_response_ie(state, model_selector, request: gr.Request):
    """Log and record a "flag" for a single-model image edit.

    Returns:
        A 7-tuple: two empty strings, a fresh 512x512 PIL-typed ``gr.Image``,
        another empty string, and three ``disable_btn`` values.
    """
    ie_logger.info(f"flag. ip: {get_ip(request)}")
    vote_last_response_ie(state, "flag", model_selector, request)
    blank_image = gr.Image(height=512, width=512, type="pil")
    return ("", "", blank_image, "", disable_btn, disable_btn, disable_btn)
## Image Editing Multi (IEM) Side-by-Side and Battle
def leftvote_last_response_iem(
    state0, state1, model_selector0, model_selector1, request: gr.Request
):
    """Log and record a "leftvote" for a pair of image edits.

    Returns:
        A 10-tuple: two ``gr.Markdown`` model labels, two empty strings, a
        fresh 512x512 PIL-typed ``gr.Image``, another empty string, and four
        ``disable_btn`` values. When ``model_selector0`` is empty the labels
        are visible and show the second ``_``-separated field of each
        ``model_name``; otherwise they carry the full name and are hidden.
    """
    iem_logger.info(f"leftvote (anony). ip: {get_ip(request)}")
    vote_last_response_iem(
        [state0, state1], "leftvote", [model_selector0, model_selector1], request
    )
    if model_selector0 != "":
        label_a = gr.Markdown(state0.model_name, visible=False)
        label_b = gr.Markdown(state1.model_name, visible=False)
    else:
        label_a = gr.Markdown(f"### Model A: {state0.model_name.split('_')[1]}", visible=True)
        label_b = gr.Markdown(f"### Model B: {state1.model_name.split('_')[1]}", visible=True)
    blank_image = gr.Image(height=512, width=512, type="pil")
    return (label_a, label_b, "", "", blank_image, "") + (disable_btn,) * 4
def rightvote_last_response_iem(
    state0, state1, model_selector0, model_selector1, request: gr.Request
):
    """Log and record a "rightvote" for a pair of image edits.

    Returns:
        A 10-tuple: two ``gr.Markdown`` model labels, two empty strings, a
        fresh 512x512 PIL-typed ``gr.Image``, another empty string, and four
        ``disable_btn`` values. When ``model_selector0`` is empty the labels
        are visible and show the second ``_``-separated field of each
        ``model_name``; otherwise they carry the full name and are hidden.
    """
    iem_logger.info(f"rightvote (anony). ip: {get_ip(request)}")
    vote_last_response_iem(
        [state0, state1], "rightvote", [model_selector0, model_selector1], request
    )
    if model_selector0 != "":
        label_a = gr.Markdown(state0.model_name, visible=False)
        label_b = gr.Markdown(state1.model_name, visible=False)
    else:
        label_a = gr.Markdown(f"### Model A: {state0.model_name.split('_')[1]}", visible=True)
        label_b = gr.Markdown(f"### Model B: {state1.model_name.split('_')[1]}", visible=True)
    blank_image = gr.Image(height=512, width=512, type="pil")
    return (label_a, label_b, "", "", blank_image, "") + (disable_btn,) * 4
def tievote_last_response_iem(
    state0, state1, model_selector0, model_selector1, request: gr.Request
):
    """Log and record a "tievote" for a pair of image edits.

    Returns:
        A 10-tuple: two ``gr.Markdown`` model labels, two empty strings, a
        fresh 512x512 PIL-typed ``gr.Image``, another empty string, and four
        ``disable_btn`` values. When ``model_selector0`` is empty the labels
        are visible and show the second ``_``-separated field of each
        ``model_name``; otherwise they carry the full name and are hidden.
    """
    iem_logger.info(f"tievote (anony). ip: {get_ip(request)}")
    vote_last_response_iem(
        [state0, state1], "tievote", [model_selector0, model_selector1], request
    )
    if model_selector0 != "":
        label_a = gr.Markdown(state0.model_name, visible=False)
        label_b = gr.Markdown(state1.model_name, visible=False)
    else:
        label_a = gr.Markdown(f"### Model A: {state0.model_name.split('_')[1]}", visible=True)
        label_b = gr.Markdown(f"### Model B: {state1.model_name.split('_')[1]}", visible=True)
    blank_image = gr.Image(height=512, width=512, type="pil")
    return (label_a, label_b, "", "", blank_image, "") + (disable_btn,) * 4
def bothbad_vote_last_response_iem(
    state0, state1, model_selector0, model_selector1, request: gr.Request
):
    """Log and record a "bothbad_vote" for a pair of image edits.

    Returns:
        A 10-tuple: two ``gr.Markdown`` model labels, two empty strings, a
        fresh 512x512 PIL-typed ``gr.Image``, another empty string, and four
        ``disable_btn`` values. When ``model_selector0`` is empty the labels
        are visible and show the second ``_``-separated field of each
        ``model_name``; otherwise they carry the full name and are hidden.
    """
    iem_logger.info(f"bothbad_vote (anony). ip: {get_ip(request)}")
    vote_last_response_iem(
        [state0, state1], "bothbad_vote", [model_selector0, model_selector1], request
    )
    if model_selector0 != "":
        label_a = gr.Markdown(state0.model_name, visible=False)
        label_b = gr.Markdown(state1.model_name, visible=False)
    else:
        label_a = gr.Markdown(f"### Model A: {state0.model_name.split('_')[1]}", visible=True)
        label_b = gr.Markdown(f"### Model B: {state1.model_name.split('_')[1]}", visible=True)
    blank_image = gr.Image(height=512, width=512, type="pil")
    return (label_a, label_b, "", "", blank_image, "") + (disable_btn,) * 4
## Video Generation (VG) Single Model Direct Chat
def upvote_last_response_vg(state, model_selector, request: gr.Request):
    """Log and record an "upvote" for a single-model video generation.

    Returns:
        A 4-tuple: an empty string followed by three ``disable_btn`` values.
    """
    vg_logger.info(f"upvote. ip: {get_ip(request)}")
    vote_last_response_vg(state, "upvote", model_selector, request)
    return ("", disable_btn, disable_btn, disable_btn)
def downvote_last_response_vg(state, model_selector, request: gr.Request):
    """Log and record a "downvote" for a single-model video generation.

    Returns:
        A 4-tuple: an empty string followed by three ``disable_btn`` values.
    """
    vg_logger.info(f"downvote. ip: {get_ip(request)}")
    vote_last_response_vg(state, "downvote", model_selector, request)
    return ("", disable_btn, disable_btn, disable_btn)
def flag_last_response_vg(state, model_selector, request: gr.Request):
    """Log and record a "flag" for a single-model video generation.

    Returns:
        A 4-tuple: an empty string followed by three ``disable_btn`` values.
    """
    vg_logger.info(f"flag. ip: {get_ip(request)}")
    vote_last_response_vg(state, "flag", model_selector, request)
    return ("", disable_btn, disable_btn, disable_btn)
## Video Generation Multi (VGM) Side-by-Side and Battle
def leftvote_last_response_vgm(
    state0, state1, model_selector0, model_selector1, request: gr.Request
):
    """Log and record a "leftvote" for a pair of video generations.

    Returns:
        A 7-tuple: an empty string, four ``disable_btn`` values and two
        ``gr.Markdown`` labels. When ``model_selector0`` is empty the labels
        are visible and show the second ``_``-separated field of each
        ``model_name``; otherwise they carry the full name and are hidden.
    """
    vgm_logger.info(f"leftvote (named). ip: {get_ip(request)}")
    vote_last_response_vgm(
        [state0, state1], "leftvote", [model_selector0, model_selector1], request
    )
    if model_selector0 != "":
        labels = (
            gr.Markdown(state0.model_name, visible=False),
            gr.Markdown(state1.model_name, visible=False),
        )
    else:
        labels = (
            gr.Markdown(f"### Model A: {state0.model_name.split('_')[1]}", visible=True),
            gr.Markdown(f"### Model B: {state1.model_name.split('_')[1]}", visible=True),
        )
    return ("",) + (disable_btn,) * 4 + labels
def rightvote_last_response_vgm(
    state0, state1, model_selector0, model_selector1, request: gr.Request
):
    """Log and record a "rightvote" for a pair of video generations.

    Returns:
        A 7-tuple: an empty string, four ``disable_btn`` values and two
        ``gr.Markdown`` labels. When ``model_selector0`` is empty the labels
        are visible and show the second ``_``-separated field of each
        ``model_name``; otherwise they carry the full name and are hidden.
    """
    vgm_logger.info(f"rightvote (named). ip: {get_ip(request)}")
    vote_last_response_vgm(
        [state0, state1], "rightvote", [model_selector0, model_selector1], request
    )
    if model_selector0 != "":
        labels = (
            gr.Markdown(state0.model_name, visible=False),
            gr.Markdown(state1.model_name, visible=False),
        )
    else:
        labels = (
            gr.Markdown(f"### Model A: {state0.model_name.split('_')[1]}", visible=True),
            gr.Markdown(f"### Model B: {state1.model_name.split('_')[1]}", visible=True),
        )
    return ("",) + (disable_btn,) * 4 + labels
def tievote_last_response_vgm(
    state0, state1, model_selector0, model_selector1, request: gr.Request
):
    """Log and record a "tievote" for a pair of video generations.

    Returns:
        A 7-tuple: an empty string, four ``disable_btn`` values and two
        ``gr.Markdown`` labels. When ``model_selector0`` is empty the labels
        are visible and show the second ``_``-separated field of each
        ``model_name``; otherwise they carry the full name and are hidden.
    """
    vgm_logger.info(f"tievote (named). ip: {get_ip(request)}")
    vote_last_response_vgm(
        [state0, state1], "tievote", [model_selector0, model_selector1], request
    )
    if model_selector0 != "":
        labels = (
            gr.Markdown(state0.model_name, visible=False),
            gr.Markdown(state1.model_name, visible=False),
        )
    else:
        labels = (
            gr.Markdown(f"### Model A: {state0.model_name.split('_')[1]}", visible=True),
            gr.Markdown(f"### Model B: {state1.model_name.split('_')[1]}", visible=True),
        )
    return ("",) + (disable_btn,) * 4 + labels
def bothbad_vote_last_response_vgm(
    state0, state1, model_selector0, model_selector1, request: gr.Request
):
    """Log and record a "bothbad_vote" for a pair of video generations.

    Returns:
        A 7-tuple: an empty string, four ``disable_btn`` values and two
        ``gr.Markdown`` labels. When ``model_selector0`` is empty the labels
        are visible and show the second ``_``-separated field of each
        ``model_name``; otherwise they carry the full name and are hidden.
    """
    vgm_logger.info(f"bothbad_vote (named). ip: {get_ip(request)}")
    vote_last_response_vgm(
        [state0, state1], "bothbad_vote", [model_selector0, model_selector1], request
    )
    if model_selector0 != "":
        labels = (
            gr.Markdown(state0.model_name, visible=False),
            gr.Markdown(state1.model_name, visible=False),
        )
    else:
        labels = (
            gr.Markdown(f"### Model A: {state0.model_name.split('_')[1]}", visible=True),
            gr.Markdown(f"### Model B: {state1.model_name.split('_')[1]}", visible=True),
        )
    return ("",) + (disable_btn,) * 4 + labels
# JavaScript source (kept as a string) for a four-argument function that
# renders the element with id "share-region-named" through html2canvas,
# triggers a download of the resulting PNG as "chatbot-arena.png", and
# returns its four arguments unchanged. The capture runs asynchronously in
# promise callbacks, so the return happens before the download starts.
# NOTE(review): assumes html2canvas is already loaded on the page - confirm.
share_js = """
function (a, b, c, d) {
const captureElement = document.querySelector('#share-region-named');
html2canvas(captureElement)
.then(canvas => {
canvas.style.display = 'none'
document.body.appendChild(canvas)
return canvas
})
.then(canvas => {
const image = canvas.toDataURL('image/png')
const a = document.createElement('a')
a.setAttribute('download', 'chatbot-arena.png')
a.setAttribute('href', image)
a.click()
canvas.remove()
});
return [a, b, c, d];
}
"""
def share_click_igm(state0, state1, model_selector0, model_selector1, request: gr.Request):
    """Log a share click and, if both states exist, record it as a "share" vote.

    Nothing is recorded when either state is ``None``. Always returns ``None``.
    """
    igm_logger.info(f"share (anony). ip: {get_ip(request)}")
    if state0 is None or state1 is None:
        return
    vote_last_response_igm(
        [state0, state1], "share", [model_selector0, model_selector1], request
    )
def share_click_iem(state0, state1, model_selector0, model_selector1, request: gr.Request):
    """Log a share click and, if both states exist, record it as a "share" vote.

    Nothing is recorded when either state is ``None``. Always returns ``None``.
    """
    iem_logger.info(f"share (anony). ip: {get_ip(request)}")
    if state0 is None or state1 is None:
        return
    vote_last_response_iem(
        [state0, state1], "share", [model_selector0, model_selector1], request
    )
## All Generation Gradio Interface
class ImageStateIG:
    """Per-conversation state for image generation.

    Holds a random conversation id, the model name, the prompt and the
    generated output. ``prompt`` and ``output`` start as ``None``.
    """

    def __init__(self, model_name):
        self.model_name = model_name
        # 32-character hex string from a random UUID.
        self.conv_id = uuid.uuid4().hex
        self.prompt = self.output = None

    def dict(self):
        """Return the loggable fields as a dict; ``output`` is not included."""
        fields = ("conv_id", "model_name", "prompt")
        return {field: getattr(self, field) for field in fields}
class ImageStateIE:
    """Per-conversation state for image editing.

    Holds a random conversation id, the model name, the three prompts
    (source, target, instruction), the source image and the edited output.
    Everything except ``conv_id`` and ``model_name`` starts as ``None``.
    """

    def __init__(self, model_name):
        self.model_name = model_name
        # 32-character hex string from a random UUID.
        self.conv_id = uuid.uuid4().hex
        self.source_prompt = self.target_prompt = self.instruct_prompt = None
        self.source_image = self.output = None

    def dict(self):
        """Return the loggable fields as a dict; the images are not included."""
        fields = (
            "conv_id",
            "model_name",
            "source_prompt",
            "target_prompt",
            "instruct_prompt",
        )
        return {field: getattr(self, field) for field in fields}
class VideoStateVG:
    """Per-conversation state for video generation.

    Holds a random conversation id, the model name, the prompt and the
    generated output. ``prompt`` and ``output`` start as ``None``.
    """

    def __init__(self, model_name):
        self.model_name = model_name
        # 32-character hex string from a random UUID.
        self.conv_id = uuid.uuid4().hex
        self.prompt = self.output = None

    def dict(self):
        """Return the loggable fields as a dict; ``output`` is not included."""
        fields = ("conv_id", "model_name", "prompt")
        return {field: getattr(self, field) for field in fields}
def generate_ig(gen_func, state, text, model_name, request: gr.Request):
    """Generate one image, yield it, then log the request and archive the image.

    This is a generator: it yields ``(state, generated_image)`` once, and the
    logging/archiving below the ``yield`` runs when the generator is resumed.

    Args:
        gen_func: Callable invoked as ``gen_func(text, model_name)``; its
            result must offer ``save(fp, format)`` (presumably a PIL image).
        state: Existing ``ImageStateIG`` or ``None`` to create a new one.
        text: Prompt; must be non-empty.
        model_name: Model identifier; must be non-empty.
        request: Gradio request, used to resolve the client IP.
    """
    # NOTE(review): confirm gr.Warning is raisable in the pinned Gradio
    # version; where it is a plain function this `raise` fails with TypeError.
    if not text:
        raise gr.Warning("Prompt cannot be empty.")
    if not model_name:
        raise gr.Warning("Model name cannot be empty.")
    if state is None:
        state = ImageStateIG(model_name)
    ip = get_ip(request)
    ig_logger.info(f"generate. ip: {ip}")

    start_tstamp = time.time()
    generated_image = gen_func(text, model_name)
    state.prompt = text
    state.output = generated_image
    state.model_name = model_name
    yield state, generated_image

    finish_tstamp = time.time()
    with open(get_conv_log_filename(), "a") as fout:
        data = {
            "tstamp": round(finish_tstamp, 4),
            "type": "chat",
            "model": model_name,
            "gen_params": {},
            "start": round(start_tstamp, 4),
            "finish": round(finish_tstamp, 4),
            "state": state.dict(),
            "ip": get_ip(request),
        }
        fout.write(json.dumps(data) + "\n")
        append_json_item_on_log_server(data, get_conv_log_filename())

    output_file = f'{IMAGE_DIR}/generation/{state.conv_id}.jpg'
    os.makedirs(os.path.dirname(output_file), exist_ok=True)
    # JPEG data is binary, so the file object must be opened in binary mode.
    with open(output_file, 'wb') as f:
        state.output.save(f, 'JPEG')
    save_image_file_on_log_server(output_file)
def generate_igm(gen_func, state0, state1, text, model_name0, model_name1, request: gr.Request):
    """Generate two images with named models, yield them, then log and archive.

    This is a generator: it yields
    ``(state0, state1, generated_image0, generated_image1)`` once, and the
    logging/archiving below the ``yield`` runs when the generator is resumed.
    One "chat" record is written per model.

    Args:
        gen_func: Callable invoked as ``gen_func(text, model_name0,
            model_name1)`` returning two images (objects with
            ``save(fp, format)``, presumably PIL images).
        state0, state1: Existing ``ImageStateIG`` objects or ``None``.
        text: Prompt; must be non-empty.
        model_name0, model_name1: Model identifiers; must be non-empty. A
            leading ``"### Model A: "`` / ``"### Model B: "`` label is removed.
        request: Gradio request, used to resolve the client IP.
    """
    # NOTE(review): confirm gr.Warning is raisable in the pinned Gradio
    # version; where it is a plain function this `raise` fails with TypeError.
    if not text:
        raise gr.Warning("Prompt cannot be empty.")
    if not model_name0:
        raise gr.Warning("Model name A cannot be empty.")
    if not model_name1:
        raise gr.Warning("Model name B cannot be empty.")
    if state0 is None:
        state0 = ImageStateIG(model_name0)
    if state1 is None:
        state1 = ImageStateIG(model_name1)
    ip = get_ip(request)
    igm_logger.info(f"generate. ip: {ip}")

    start_tstamp = time.time()
    # Remove ### Model (A|B): from model name
    model_name0 = re.sub(r"### Model A: ", "", model_name0)
    model_name1 = re.sub(r"### Model B: ", "", model_name1)
    generated_image0, generated_image1 = gen_func(text, model_name0, model_name1)
    state0.prompt = text
    state1.prompt = text
    state0.output = generated_image0
    state1.output = generated_image1
    state0.model_name = model_name0
    state1.model_name = model_name1
    yield state0, state1, generated_image0, generated_image1

    finish_tstamp = time.time()
    with open(get_conv_log_filename(), "a") as fout:
        # One log record per model, written in A-then-B order.
        for model_name, state in ((model_name0, state0), (model_name1, state1)):
            data = {
                "tstamp": round(finish_tstamp, 4),
                "type": "chat",
                "model": model_name,
                "gen_params": {},
                "start": round(start_tstamp, 4),
                "finish": round(finish_tstamp, 4),
                "state": state.dict(),
                "ip": get_ip(request),
            }
            fout.write(json.dumps(data) + "\n")
            append_json_item_on_log_server(data, get_conv_log_filename())

    for state in (state0, state1):
        output_file = f'{IMAGE_DIR}/generation/{state.conv_id}.jpg'
        os.makedirs(os.path.dirname(output_file), exist_ok=True)
        # JPEG data is binary, so the file object must be opened in binary mode.
        with open(output_file, 'wb') as f:
            state.output.save(f, 'JPEG')
        save_image_file_on_log_server(output_file)
def generate_igm_annoy(gen_func, state0, state1, text, model_name0, model_name1, request: gr.Request):
    """Generate two images with models chosen by ``gen_func``, then log and archive.

    The incoming model names are only used to create missing states; they are
    then blanked and ``gen_func`` is called with empty names and returns the
    names it actually used. This is a generator: it yields
    ``(state0, state1, image0, image1, markdown_a, markdown_b)`` once, with
    both Markdown labels hidden, and the logging/archiving below the ``yield``
    runs when the generator is resumed.

    Args:
        gen_func: Callable invoked as ``gen_func(text, "", "")`` returning
            ``(image0, image1, model_name0, model_name1)``; the images must
            offer ``save(fp, format)`` (presumably PIL images).
        state0, state1: Existing ``ImageStateIG`` objects or ``None``.
        text: Prompt; must be non-empty.
        model_name0, model_name1: Only used when creating missing states.
        request: Gradio request, used to resolve the client IP.
    """
    # NOTE(review): confirm gr.Warning is raisable in the pinned Gradio
    # version; where it is a plain function this `raise` fails with TypeError.
    if not text:
        raise gr.Warning("Prompt cannot be empty.")
    if state0 is None:
        state0 = ImageStateIG(model_name0)
    if state1 is None:
        state1 = ImageStateIG(model_name1)
    ip = get_ip(request)
    igm_logger.info(f"generate. ip: {ip}")

    start_tstamp = time.time()
    # Empty names are passed in; gen_func reports the models it picked.
    model_name0 = ""
    model_name1 = ""
    generated_image0, generated_image1, model_name0, model_name1 = gen_func(text, model_name0, model_name1)
    state0.prompt = text
    state1.prompt = text
    state0.output = generated_image0
    state1.output = generated_image1
    state0.model_name = model_name0
    state1.model_name = model_name1
    yield state0, state1, generated_image0, generated_image1, \
        gr.Markdown(f"### Model A: {model_name0}", visible=False), gr.Markdown(f"### Model B: {model_name1}", visible=False)

    finish_tstamp = time.time()
    with open(get_conv_log_filename(), "a") as fout:
        # One log record per model, written in A-then-B order.
        for model_name, state in ((model_name0, state0), (model_name1, state1)):
            data = {
                "tstamp": round(finish_tstamp, 4),
                "type": "chat",
                "model": model_name,
                "gen_params": {},
                "start": round(start_tstamp, 4),
                "finish": round(finish_tstamp, 4),
                "state": state.dict(),
                "ip": get_ip(request),
            }
            fout.write(json.dumps(data) + "\n")
            append_json_item_on_log_server(data, get_conv_log_filename())

    for state in (state0, state1):
        output_file = f'{IMAGE_DIR}/generation/{state.conv_id}.jpg'
        os.makedirs(os.path.dirname(output_file), exist_ok=True)
        # JPEG data is binary, so the file object must be opened in binary mode.
        with open(output_file, 'wb') as f:
            state.output.save(f, 'JPEG')
        save_image_file_on_log_server(output_file)
def generate_ie(gen_func, state, source_text, target_text, instruct_text, source_image, model_name, request: gr.Request):
    """Edit one image, yield the result, then log the request and archive images.

    This is a generator: it yields ``(state, generated_image)`` once, and the
    logging/archiving below the ``yield`` runs when the generator is resumed.
    The source image is stored as ``{conv_id}_src.jpg`` and the result as
    ``{conv_id}_out.jpg`` under ``{IMAGE_DIR}/edition/``.

    Args:
        gen_func: Callable invoked as ``gen_func(source_text, target_text,
            instruct_text, source_image, model_name)``; its result must offer
            ``save(fp, format)`` (presumably a PIL image).
        state: Existing ``ImageStateIE`` or ``None`` to create a new one.
        source_text, target_text, instruct_text: Prompts; all must be non-empty.
        source_image: Image to edit; must be truthy and offer
            ``save(fp, format)``.
        model_name: Model identifier; must be non-empty.
        request: Gradio request, used to resolve the client IP.
    """
    # NOTE(review): confirm gr.Warning is raisable in the pinned Gradio
    # version; where it is a plain function this `raise` fails with TypeError.
    if not source_text:
        raise gr.Warning("Source prompt cannot be empty.")
    if not target_text:
        raise gr.Warning("Target prompt cannot be empty.")
    if not instruct_text:
        raise gr.Warning("Instruction prompt cannot be empty.")
    if not source_image:
        raise gr.Warning("Source image cannot be empty.")
    if not model_name:
        raise gr.Warning("Model name cannot be empty.")
    if state is None:
        state = ImageStateIE(model_name)
    ip = get_ip(request)
    # Image-editing requests belong in the image-editing log, not the
    # image-generation one.
    ie_logger.info(f"generate. ip: {ip}")

    start_tstamp = time.time()
    generated_image = gen_func(source_text, target_text, instruct_text, source_image, model_name)
    state.source_prompt = source_text
    state.target_prompt = target_text
    state.instruct_prompt = instruct_text
    state.source_image = source_image
    state.output = generated_image
    state.model_name = model_name
    yield state, generated_image

    finish_tstamp = time.time()
    with open(get_conv_log_filename(), "a") as fout:
        data = {
            "tstamp": round(finish_tstamp, 4),
            "type": "chat",
            "model": model_name,
            "gen_params": {},
            "start": round(start_tstamp, 4),
            "finish": round(finish_tstamp, 4),
            "state": state.dict(),
            "ip": get_ip(request),
        }
        fout.write(json.dumps(data) + "\n")
        append_json_item_on_log_server(data, get_conv_log_filename())

    src_img_file = f'{IMAGE_DIR}/edition/{state.conv_id}_src.jpg'
    os.makedirs(os.path.dirname(src_img_file), exist_ok=True)
    # JPEG data is binary, so the file objects must be opened in binary mode.
    with open(src_img_file, 'wb') as f:
        state.source_image.save(f, 'JPEG')
    output_file = f'{IMAGE_DIR}/edition/{state.conv_id}_out.jpg'
    with open(output_file, 'wb') as f:
        state.output.save(f, 'JPEG')
    save_image_file_on_log_server(src_img_file)
    save_image_file_on_log_server(output_file)
def generate_iem(gen_func, state0, state1, source_text, target_text, instruct_text, source_image, model_name0, model_name1, request: gr.Request):
    """Edit one image with two named models, yield results, then log and archive.

    This is a generator: it yields
    ``(state0, state1, generated_image0, generated_image1)`` once, and the
    logging/archiving below the ``yield`` runs when the generator is resumed.
    One "chat" record is written per model; for each state the source image is
    stored as ``{conv_id}_src.jpg`` and the result as ``{conv_id}_out.jpg``
    under ``{IMAGE_DIR}/edition/``.

    Args:
        gen_func: Callable invoked as ``gen_func(source_text, target_text,
            instruct_text, source_image, model_name0, model_name1)`` returning
            two images (objects with ``save(fp, format)``).
        state0, state1: Existing ``ImageStateIE`` objects or ``None``.
        source_text, target_text, instruct_text: Prompts; all must be non-empty.
        source_image: Image to edit; must be truthy and offer
            ``save(fp, format)``.
        model_name0, model_name1: Model identifiers; must be non-empty. A
            leading ``"### Model A: "`` / ``"### Model B: "`` label is removed.
        request: Gradio request, used to resolve the client IP.
    """
    # NOTE(review): confirm gr.Warning is raisable in the pinned Gradio
    # version; where it is a plain function this `raise` fails with TypeError.
    if not source_text:
        raise gr.Warning("Source prompt cannot be empty.")
    if not target_text:
        raise gr.Warning("Target prompt cannot be empty.")
    if not instruct_text:
        raise gr.Warning("Instruction prompt cannot be empty.")
    if not source_image:
        raise gr.Warning("Source image cannot be empty.")
    if not model_name0:
        raise gr.Warning("Model name A cannot be empty.")
    if not model_name1:
        raise gr.Warning("Model name B cannot be empty.")
    if state0 is None:
        state0 = ImageStateIE(model_name0)
    if state1 is None:
        state1 = ImageStateIE(model_name1)
    ip = get_ip(request)
    # Image-editing requests belong in the image-editing-multi log, not the
    # image-generation-multi one.
    iem_logger.info(f"generate. ip: {ip}")

    start_tstamp = time.time()
    model_name0 = re.sub(r"### Model A: ", "", model_name0)
    model_name1 = re.sub(r"### Model B: ", "", model_name1)
    generated_image0, generated_image1 = gen_func(source_text, target_text, instruct_text, source_image, model_name0, model_name1)
    for state, image, model_name in (
        (state0, generated_image0, model_name0),
        (state1, generated_image1, model_name1),
    ):
        state.source_prompt = source_text
        state.target_prompt = target_text
        state.instruct_prompt = instruct_text
        state.source_image = source_image
        state.output = image
        state.model_name = model_name
    yield state0, state1, generated_image0, generated_image1

    finish_tstamp = time.time()
    with open(get_conv_log_filename(), "a") as fout:
        # One log record per model, written in A-then-B order.
        for model_name, state in ((model_name0, state0), (model_name1, state1)):
            data = {
                "tstamp": round(finish_tstamp, 4),
                "type": "chat",
                "model": model_name,
                "gen_params": {},
                "start": round(start_tstamp, 4),
                "finish": round(finish_tstamp, 4),
                "state": state.dict(),
                "ip": get_ip(request),
            }
            fout.write(json.dumps(data) + "\n")
            append_json_item_on_log_server(data, get_conv_log_filename())

    for state in (state0, state1):
        src_img_file = f'{IMAGE_DIR}/edition/{state.conv_id}_src.jpg'
        os.makedirs(os.path.dirname(src_img_file), exist_ok=True)
        # JPEG data is binary, so the file objects must be opened in binary mode.
        with open(src_img_file, 'wb') as f:
            state.source_image.save(f, 'JPEG')
        output_file = f'{IMAGE_DIR}/edition/{state.conv_id}_out.jpg'
        with open(output_file, 'wb') as f:
            state.output.save(f, 'JPEG')
        save_image_file_on_log_server(src_img_file)
        save_image_file_on_log_server(output_file)
def generate_iem_annoy(gen_func, state0, state1, source_text, target_text, instruct_text, source_image, model_name0, model_name1, request: gr.Request):
    """Edit one image with two models chosen by ``gen_func``, then log and archive.

    The incoming model names are only used to create missing states; they are
    then blanked and ``gen_func`` is called with empty names and returns the
    names it actually used. This is a generator: it yields
    ``(state0, state1, image0, image1, markdown_a, markdown_b)`` once, with
    both Markdown labels hidden, and the logging/archiving below the ``yield``
    runs when the generator is resumed.

    Args:
        gen_func: Callable invoked as ``gen_func(source_text, target_text,
            instruct_text, source_image, "", "")`` returning
            ``(image0, image1, model_name0, model_name1)``; the images must
            offer ``save(fp, format)``.
        state0, state1: Existing ``ImageStateIE`` objects or ``None``.
        source_text, target_text, instruct_text: Prompts; all must be non-empty.
        source_image: Image to edit; must be truthy and offer
            ``save(fp, format)``.
        model_name0, model_name1: Only used when creating missing states.
        request: Gradio request, used to resolve the client IP.
    """
    # NOTE(review): confirm gr.Warning is raisable in the pinned Gradio
    # version; where it is a plain function this `raise` fails with TypeError.
    if not source_text:
        raise gr.Warning("Source prompt cannot be empty.")
    if not target_text:
        raise gr.Warning("Target prompt cannot be empty.")
    if not instruct_text:
        raise gr.Warning("Instruction prompt cannot be empty.")
    if not source_image:
        raise gr.Warning("Source image cannot be empty.")
    if state0 is None:
        state0 = ImageStateIE(model_name0)
    if state1 is None:
        state1 = ImageStateIE(model_name1)
    ip = get_ip(request)
    # Image-editing requests belong in the image-editing-multi log, not the
    # image-generation-multi one.
    iem_logger.info(f"generate. ip: {ip}")

    start_tstamp = time.time()
    # Empty names are passed in; gen_func reports the models it picked.
    model_name0 = ""
    model_name1 = ""
    generated_image0, generated_image1, model_name0, model_name1 = gen_func(source_text, target_text, instruct_text, source_image, model_name0, model_name1)
    for state, image, model_name in (
        (state0, generated_image0, model_name0),
        (state1, generated_image1, model_name1),
    ):
        state.source_prompt = source_text
        state.target_prompt = target_text
        state.instruct_prompt = instruct_text
        state.source_image = source_image
        state.output = image
        state.model_name = model_name
    yield state0, state1, generated_image0, generated_image1, \
        gr.Markdown(f"### Model A: {model_name0}", visible=False), gr.Markdown(f"### Model B: {model_name1}", visible=False)

    finish_tstamp = time.time()
    with open(get_conv_log_filename(), "a") as fout:
        # One log record per model, written in A-then-B order.
        for model_name, state in ((model_name0, state0), (model_name1, state1)):
            data = {
                "tstamp": round(finish_tstamp, 4),
                "type": "chat",
                "model": model_name,
                "gen_params": {},
                "start": round(start_tstamp, 4),
                "finish": round(finish_tstamp, 4),
                "state": state.dict(),
                "ip": get_ip(request),
            }
            fout.write(json.dumps(data) + "\n")
            append_json_item_on_log_server(data, get_conv_log_filename())

    for state in (state0, state1):
        src_img_file = f'{IMAGE_DIR}/edition/{state.conv_id}_src.jpg'
        os.makedirs(os.path.dirname(src_img_file), exist_ok=True)
        # JPEG data is binary, so the file objects must be opened in binary mode.
        with open(src_img_file, 'wb') as f:
            state.source_image.save(f, 'JPEG')
        output_file = f'{IMAGE_DIR}/edition/{state.conv_id}_out.jpg'
        with open(output_file, 'wb') as f:
            state.output.save(f, 'JPEG')
        save_image_file_on_log_server(src_img_file)
        save_image_file_on_log_server(output_file)
def generate_vg(gen_func, state, text, model_name, request: gr.Request):
    """Generate one video, log the request, store the file, then yield it.

    This is a generator that yields ``(state, output_file)`` exactly once,
    after the log record has been written and the mp4 has been stored at
    ``{VIDEO_DIR}/generation/{conv_id}.mp4``.

    For models whose name starts with ``fal`` the output is fetched with
    ``requests.get`` (so it is presumably a URL); otherwise it is encoded
    with ``imageio.mimwrite``.
    """
    if not text:
        raise gr.Warning("Prompt cannot be empty.")
    if not model_name:
        raise gr.Warning("Model name cannot be empty.")
    if state is None:
        state = VideoStateVG(model_name)
    client_ip = get_ip(request)
    vg_logger.info(f"generate. ip: {client_ip}")

    started_at = time.time()
    video = gen_func(text, model_name)
    state.prompt = text
    state.output = video
    state.model_name = model_name
    finished_at = time.time()

    with open(get_conv_log_filename(), "a") as log_file:
        record = {
            "tstamp": round(finished_at, 4),
            "type": "chat",
            "model": model_name,
            "gen_params": {},
            "start": round(started_at, 4),
            "finish": round(finished_at, 4),
            "state": state.dict(),
            "ip": get_ip(request),
        }
        log_file.write(json.dumps(record) + "\n")
        append_json_item_on_log_server(record, get_conv_log_filename())

    video_path = f'{VIDEO_DIR}/generation/{state.conv_id}.mp4'
    os.makedirs(os.path.dirname(video_path), exist_ok=True)
    if not model_name.startswith('fal'):
        # Locally produced frames: encode them into the mp4 file.
        print("======== video shape: ========")
        print(state.output.shape)
        imageio.mimwrite(video_path, state.output, fps=8, quality=9)
    else:
        # Remote result: download the bytes and store them as-is.
        response = requests.get(state.output)
        with open(video_path, 'wb') as video_file:
            video_file.write(response.content)
    save_video_file_on_log_server(video_path)
    yield state, video_path
def generate_vgm(gen_func, state0, state1, text, model_name0, model_name1, request: gr.Request):
    """Generate videos with two named models side by side, log and save both.

    Args:
        gen_func: Callable invoked as
            ``gen_func(text, model_name0, model_name1)``; returns the two
            generated outputs in that order.
        state0: State for model A, or ``None`` to create a new
            ``VideoStateVG(model_name0)``.
        state1: State for model B, or ``None`` to create a new
            ``VideoStateVG(model_name1)``.
        text: Prompt shared by both models; must be non-empty.
        model_name0: Model A name, optionally prefixed with "### Model A: ".
        model_name1: Model B name, optionally prefixed with "### Model B: ".
        request: Gradio request, used to resolve the caller's IP for logging.

    Yields:
        A single ``(state0, state1, video_path0, video_path1)`` tuple, where
        the paths point at the saved .mp4 files under ``VIDEO_DIR/generation``.

    Raises:
        gr.Error: If ``text``, ``model_name0`` or ``model_name1`` is empty.
    """
    # gr.Warning is a plain function (it returns None), so it cannot be
    # raised; gr.Error is the exception Gradio surfaces to the user.
    if not text:
        raise gr.Error("Prompt cannot be empty.")
    if not model_name0:
        raise gr.Error("Model name A cannot be empty.")
    if not model_name1:
        raise gr.Error("Model name B cannot be empty.")
    if state0 is None:
        state0 = VideoStateVG(model_name0)
    if state1 is None:
        state1 = VideoStateVG(model_name1)

    ip = get_ip(request)
    # This is the video side-by-side path, so log to the video-multi logger
    # (the original wrote to the image-generation-multi logger).
    vgm_logger.info(f"generate. ip: {ip}")

    start_tstamp = time.time()
    # Remove ### Model (A|B): from model name
    model_name0 = re.sub(r"### Model A: ", "", model_name0)
    model_name1 = re.sub(r"### Model B: ", "", model_name1)
    generated_video0, generated_video1 = gen_func(text, model_name0, model_name1)
    state0.prompt = text
    state1.prompt = text
    state0.output = generated_video0
    state1.output = generated_video1
    state0.model_name = model_name0
    state1.model_name = model_name1
    print("====== model name =========")
    print(state0.model_name)
    print(state1.model_name)
    finish_tstamp = time.time()

    # One log record per model, written locally and mirrored to the log
    # server, in A-then-B order.
    log_filename = get_conv_log_filename()
    with open(log_filename, "a") as fout:
        for model_name, state in ((model_name0, state0), (model_name1, state1)):
            data = {
                "tstamp": round(finish_tstamp, 4),
                "type": "chat",
                "model": model_name,
                "gen_params": {},
                "start": round(start_tstamp, 4),
                "finish": round(finish_tstamp, 4),
                "state": state.dict(),
                "ip": ip,
            }
            fout.write(json.dumps(data) + "\n")
            append_json_item_on_log_server(data, log_filename)

    output_files = []
    for state in (state0, state1):
        output_file = f'{VIDEO_DIR}/generation/{state.conv_id}.mp4'
        os.makedirs(os.path.dirname(output_file), exist_ok=True)
        print(state.model_name)
        if state.model_name.startswith('fal'):
            # For 'fal' models the output is fetched over HTTP, i.e. it is
            # treated as a URL rather than as in-memory frames.
            r = requests.get(state.output)
            with open(output_file, 'wb') as outfile:
                outfile.write(r.content)
        else:
            print("======== video shape: ========")
            print(state.output)
            print(state.output.shape)
            imageio.mimwrite(output_file, state.output, fps=8, quality=9)
        save_video_file_on_log_server(output_file)
        output_files.append(output_file)

    yield state0, state1, output_files[0], output_files[1]
def generate_vgm_annoy(gen_func, state0, state1, text, model_name0, model_name1, request: gr.Request):
    """Generate videos with two models chosen by ``gen_func``, log and save both.

    The incoming ``model_name0`` / ``model_name1`` are only used to build new
    states when none are supplied. ``gen_func`` is then called with empty
    model names and returns the names it actually used.
    NOTE(review): presumably this is the anonymous battle mode where
    ``gen_func`` picks the models itself -- confirm against the caller.

    Args:
        gen_func: Callable invoked as ``gen_func(text, "", "")``; returns
            ``(video0, video1, model_name0, model_name1)``.
        state0: State for model A, or ``None`` to create a new
            ``VideoStateVG(model_name0)``.
        state1: State for model B, or ``None`` to create a new
            ``VideoStateVG(model_name1)``.
        text: Prompt shared by both models; must be non-empty.
        model_name0: Initial name for model A (see above).
        model_name1: Initial name for model B (see above).
        request: Gradio request, used to resolve the caller's IP for logging.

    Yields:
        A single tuple ``(state0, state1, video_path0, video_path1,
        markdown_a, markdown_b)``; the two ``gr.Markdown`` components carry
        the model names and are created with ``visible=False``.

    Raises:
        gr.Error: If ``text`` is empty.
    """
    # gr.Warning is a plain function (it returns None), so it cannot be
    # raised; gr.Error is the exception Gradio surfaces to the user.
    if not text:
        raise gr.Error("Prompt cannot be empty.")
    if state0 is None:
        state0 = VideoStateVG(model_name0)
    if state1 is None:
        state1 = VideoStateVG(model_name1)

    ip = get_ip(request)
    vgm_logger.info(f"generate. ip: {ip}")

    start_tstamp = time.time()
    # Empty names are passed in; gen_func reports back the models it used.
    model_name0 = ""
    model_name1 = ""
    generated_video0, generated_video1, model_name0, model_name1 = gen_func(text, model_name0, model_name1)
    state0.prompt = text
    state1.prompt = text
    state0.output = generated_video0
    state1.output = generated_video1
    state0.model_name = model_name0
    state1.model_name = model_name1
    finish_tstamp = time.time()

    # One log record per model, written locally and mirrored to the log
    # server, in A-then-B order.
    log_filename = get_conv_log_filename()
    with open(log_filename, "a") as fout:
        for model_name, state in ((model_name0, state0), (model_name1, state1)):
            data = {
                "tstamp": round(finish_tstamp, 4),
                "type": "chat",
                "model": model_name,
                "gen_params": {},
                "start": round(start_tstamp, 4),
                "finish": round(finish_tstamp, 4),
                "state": state.dict(),
                "ip": ip,
            }
            fout.write(json.dumps(data) + "\n")
            append_json_item_on_log_server(data, log_filename)

    output_files = []
    for state in (state0, state1):
        output_file = f'{VIDEO_DIR}/generation/{state.conv_id}.mp4'
        os.makedirs(os.path.dirname(output_file), exist_ok=True)
        if state.model_name.startswith('fal'):
            # For 'fal' models the output is fetched over HTTP, i.e. it is
            # treated as a URL rather than as in-memory frames.
            r = requests.get(state.output)
            with open(output_file, 'wb') as outfile:
                outfile.write(r.content)
        else:
            print("======== video shape: ========")
            print(state.output.shape)
            imageio.mimwrite(output_file, state.output, fps=8, quality=9)
        save_video_file_on_log_server(output_file)
        output_files.append(output_file)

    yield state0, state1, output_files[0], output_files[1], \
        gr.Markdown(f"### Model A: {model_name0}", visible=False), gr.Markdown(f"### Model B: {model_name1}", visible=False)