loooooong committed
Commit 8c1c7b9
1 Parent(s): 1a06c05

add fidelity tips

Files changed (1)
  1. app.py +9 -5
app.py CHANGED
@@ -16,7 +16,6 @@ from diffusers.loaders import LoraLoaderMixin
 import os
 from os.path import join as opj
 
-# run only once
 token = os.getenv("ACCESS_TOKEN")
 os.system(f"python -m pip install git+https://{token}@github.com/logn-2024/StableGarment.git")
 
@@ -24,7 +23,7 @@ from stablegarment.models import AppearanceEncoderModel,ControlNetModel
 from stablegarment.piplines import StableGarmentPipeline,StableGarmentControlNetPipeline
 
 device = "cuda" if torch.cuda.is_available() else "cpu"
-torch_dtype = torch.float16 if "cuda"==device else torch.float32
+torch_dtype = torch.float32 if device=="cpu" else torch.float16
 height = 512
 width = 384
 
@@ -37,7 +36,7 @@ garment_encoder = AppearanceEncoderModel.from_pretrained(pretrained_garment_enco
 garment_encoder = garment_encoder.to(device=device,dtype=torch_dtype)
 
 pipeline_t2i = StableGarmentPipeline.from_pretrained(base_model_path, vae=vae, torch_dtype=torch_dtype,).to(device=device) # variant="fp16"
-# pipeline = StableDiffusionPipeline.from_pretrained("SG161222/Realistic_Vision_V4.0_noVAE", vae=vae, torch_dtype=torch_dtype, variant="fp16").to(device=device)
+# pipeline = StableDiffusionPipeline.from_pretrained("SG161222/Realistic_Vision_V4.0_noVAE", vae=vae, torch_dtype=torch_dtype).to(device=device)
 pipeline_t2i.scheduler = scheduler
 
 pipeline_tryon = None
@@ -97,13 +96,16 @@ def text2image(prompt,init_image,garment_top,garment_down,style_fidelity=1.):
     garment_top = Image.open(garment_top).resize((width,height))
     garment_top = transforms.CenterCrop((height,width))(transforms.Resize(max(height, width))(garment_top))
 
+    # always enable classifier-free guidance, as it is related to the garment
+    cfg = 4 # if prompt else 0
     garment_images = [garment_top,]
     prompt = [prompt,]
     cloth_prompt = ["",]
     n_prompt = "nsfw, unsaturated, abnormal, unnatural, artifact"
     negative_prompt = [n_prompt]
+
     images = pipeline_t2i(prompt,negative_prompt=negative_prompt,cloth_prompt=cloth_prompt,height=height,width=width,
-                          num_inference_steps=30,guidance_scale=4,num_images_per_prompt=1,style_fidelity=style_fidelity,
+                          num_inference_steps=30,guidance_scale=cfg,num_images_per_prompt=1,style_fidelity=style_fidelity,
                           garment_encoder=garment_encoder,garment_image=garment_images,).images
     return images[0]
 
@@ -170,8 +172,10 @@ with gr.Blocks(css = ".output-image, .input-image, .image-preview {height: 400px
         with gr.Row():
             t2i_only = gr.Checkbox(label="t2i with garment", info="Only text and garment.", elem_id="t2i_switch", value=True, interactive=False,)
             run_button = gr.Button(value="Run")
-            style_fidelity = gr.Slider(0, 1, value=1, label="fidelity(for t2i)") # , info=""
         t2i_only.change(fn=set_mode,inputs=[t2i_only,init_image,prompt],outputs=[init_image,prompt,])
+        with gr.Accordion("advanced options", open=False):
+            gr.Markdown("Garment fidelity control (turn it down to reduce white edges).")
+            style_fidelity = gr.Slider(0, 1, value=1, label="fidelity (only for t2i)") # , info=""
         with gr.Column():
             gallery = gr.Image()
             run_button.click(fn=infer,
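
Usage note (not part of the commit): a minimal sketch of how the fidelity tip plays out when calling the pipeline directly. It assumes pipeline_t2i, garment_encoder, height, and width are already set up as in app.py above; the garment image path and prompt text are made up for illustration.

# Sketch only: assumes pipeline_t2i, garment_encoder, height, width exist as defined in app.py.
from PIL import Image
from torchvision import transforms

garment_top = Image.open("garment_top.jpg").resize((width, height))   # hypothetical image path
garment_top = transforms.CenterCrop((height, width))(transforms.Resize(max(height, width))(garment_top))

images = pipeline_t2i(
    ["a woman wearing this shirt"],                        # example prompt
    negative_prompt=["nsfw, unsaturated, abnormal, unnatural, artifact"],
    cloth_prompt=[""],
    height=height, width=width,
    num_inference_steps=30, guidance_scale=4, num_images_per_prompt=1,
    style_fidelity=0.5,        # lowered from the default 1.0 to reduce white edges around the garment
    garment_encoder=garment_encoder,
    garment_image=[garment_top],
).images
images[0].save("result.png")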