MohamedRashad commited on
Commit
cbeb9ba
1 Parent(s): 4123c45

chore: Add Gradio-based image enhancement and captioning functionality

Browse files
Files changed (1) hide show
  1. app.py +92 -0
app.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import spaces
3
+ import torch
4
+ from gradio_client import Client, handle_file
5
+ from colorama import Fore, Style
6
+ from diffusers import AutoPipelineForImage2Image
7
+ from PIL import Image
8
+
9
+ joy_client = Client("fancyfeast/joy-caption-alpha-two")
10
+ qwen_client = Client("Qwen/Qwen2.5-72B-Instruct")
11
+
12
+ pipeline = AutoPipelineForImage2Image.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
13
+
14
+ lora_ids = {
15
+ "Realism": "XLabs-AI/flux-RealismLora",
16
+ "Cartoonism": "aleksa-codes/flux-ghibsky-illustration",
17
+ }
18
+
19
+ def load_lora(lora_name):
20
+ print(f"Loading LoRA model: {lora_name}")
21
+ global pipeline
22
+ pipeline.unload_lora_weights()
23
+ pipeline.load_lora_weights(lora_ids[lora_name])
24
+ pipeline.enable_model_cpu_offload()
25
+ print(f"{Fore.GREEN}LoRA model loaded{Style.RESET_ALL}")
26
+
27
+ def describe_image(image_path):
28
+ print(f"Describing image: {image_path}")
29
+ image_description = joy_client.predict(
30
+ input_image=handle_file(image_path),
31
+ caption_type="Descriptive",
32
+ caption_length="long",
33
+ extra_options=[],
34
+ name_input="",
35
+ custom_prompt="",
36
+ api_name="/stream_chat"
37
+ )[-1]
38
+ print(f"{Fore.GREEN}{image_description}{Style.RESET_ALL}")
39
+ return image_description
40
+
41
+ def refine_prompt(image_description):
42
+ print(f"Improving prompt: {image_description}")
43
+ qwen_prompt = f"""This is the description of the image: {image_description}
44
+
45
+ And those some good AI Art Prompts:
46
+ - a cat on a windowsill gazing out at a starry night sky and distant city lights
47
+ - a fisherman casting a line into a peaceful village lake surrounded by quaint cottages
48
+ - cozy mountain cabin covered in snow, with smoke curling from the chimney and a warm, inviting light spilling through the windows
49
+ - Mykonos
50
+ - an orange Lamborghini driving down a hill road at night with a beautiful ocean view in the background, side view, no text
51
+ - a small Yorkie on a windowsill during a snowy winter night, with a warm, cozy glow from inside and soft snowflakes drifting outside
52
+ - serene Japanese garden with a koi pond and a traditional tea house, nestled under a canopy of cherry blossoms in full bloom
53
+ - the most beautiful place in the universe
54
+
55
+ Based on what i gave you, Write a great short AI Art Prompt for me that is based on the image description above (Don't write anything else, just the prompt)
56
+ """
57
+ refined_prompt = qwen_client.predict(
58
+ query=qwen_prompt,
59
+ history=[],
60
+ system="You are Qwen, created by Alibaba Cloud. You are a helpful assistant.",
61
+ api_name="/model_chat"
62
+ )[1][0][-1]
63
+ print(f"{Fore.GREEN}{refined_prompt}{Style.RESET_ALL}")
64
+ return refined_prompt
65
+
66
+ @spaces.GPU
67
+ def img2img_infer(image_path, image_description):
68
+ pil_image = Image.open(image_path)
69
+ width, height = pil_image.size
70
+ enhanced_image = pipeline(f'GHIBSKY style, {image_description}', image=pil_image).images[0]
71
+ enhanced_image = enhanced_image.resize((width, height))
72
+ return enhanced_image
73
+
74
+
75
+ with gr.Blocks(title="Magnific") as demo:
76
+ with gr.Row():
77
+ with gr.Column():
78
+ image_path = gr.Image(label="Image", type="filepath")
79
+ lora_dropdown = gr.Dropdown(label="LoRA Model", choices=list(lora_ids.keys()), value=None)
80
+ describe_btn = gr.Button(value="Describe Image", variant="primary")
81
+ with gr.Row(equal_height=True):
82
+ image_description = gr.Textbox(label="Image Description", scale=4)
83
+ refine_prompt_btn = gr.Button(value="Refine", variant="primary", scale=1)
84
+ submit_btn = gr.Button(value="Submit", variant="primary")
85
+ enhanced_image = gr.Image(label="Enhanced Image", type="pil")
86
+
87
+ lora_dropdown.change(load_lora, inputs=lora_dropdown)
88
+ refine_prompt_btn.click(refine_prompt, inputs=image_description, outputs=image_description)
89
+ describe_btn.click(describe_image, inputs=image_path, outputs=image_description)
90
+ submit_btn.click(img2img_infer, inputs=[image_path, image_description], outputs=enhanced_image)
91
+
92
+ demo.queue().launch(share=False)