File size: 8,332 Bytes
f19da68
 
ffead1e
f19da68
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fc362b8
f19da68
4cae45e
 
 
 
 
 
 
 
 
fc362b8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f66ab33
f9b54be
 
 
f1dff10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fc362b8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9e15cd8
 
 
d2f25e6
7c193c2
 
9e15cd8
6f1239e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fc362b8
 
4cae45e
 
559b00c
4cae45e
 
 
 
 
 
 
559b00c
4cae45e
 
 
ffead1e
 
4cae45e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
import gradio as gr
import json
import torch
import time
import random
try:
    # Only on HuggingFace
    import spaces
    is_space_imported = True
except ImportError:
    is_space_imported = False

from tqdm import tqdm
from huggingface_hub import snapshot_download
from models import AudioDiffusion, DDPMScheduler
from audioldm.audio.stft import TacotronSTFT
from audioldm.variational_autoencoder import AutoencoderKL

# Old import
import numpy as np
import torch.nn.functional as F
from torchvision.transforms.functional import normalize
from huggingface_hub import hf_hub_download
from gradio_imageslider import ImageSlider
from briarmbg import BriaRMBG
import PIL
from PIL import Image
from typing import Tuple

# Largest signed 64-bit integer; the commented-out TANGO UI code in this file
# uses it as the upper bound for random seeds.
max_64_bit_int = (1 << 63) - 1

# Automatic device detection: the first CUDA GPU when one is visible, else CPU.
# `device_type` is the bare backend name, `device_selection` the concrete device.
if not torch.cuda.is_available():
    device_type = device_selection = "cpu"
else:
    device_type, device_selection = "cuda", "cuda:0"

class Tango:
    """Thin wrapper around a TANGO text-to-audio checkpoint from the Hugging Face Hub.

    NOTE: only the VAE is currently constructed. The STFT module, the diffusion
    model, the weight loading and the scheduler are commented out in
    ``__init__``, so ``self.model`` and ``self.scheduler`` are never assigned and
    ``generate`` / ``generate_for_batch`` raise ``AttributeError`` until that
    code is re-enabled.
    """

    def __init__(self, name = "declare-lab/tango2", device = device_selection):
        """Download checkpoint *name* and build its VAE on *device*.

        Args:
            name: Hugging Face Hub repo id of the TANGO checkpoint.
            device: torch device string the VAE is moved to. The default is
                bound once, when the class is defined, to the module-level
                ``device_selection``.
        """
        path = snapshot_download(repo_id = name)

        vae_config = self._load_json(path, "vae_config.json")
        # The next two configs are only used by the disabled code below; they
        # are still read, so a checkpoint missing them raises at construction.
        stft_config = self._load_json(path, "stft_config.json")
        main_config = self._load_json(path, "main_config.json")

        self.vae = AutoencoderKL(**vae_config).to(device)
#        self.stft = TacotronSTFT(**stft_config).to(device)
#        self.model = AudioDiffusion(**main_config).to(device)
#        
#        vae_weights = torch.load("{}/pytorch_model_vae.bin".format(path), map_location = device)
#        stft_weights = torch.load("{}/pytorch_model_stft.bin".format(path), map_location = device)
#        main_weights = torch.load("{}/pytorch_model_main.bin".format(path), map_location = device)
#        
#        self.vae.load_state_dict(vae_weights)
#        self.stft.load_state_dict(stft_weights)
#        self.model.load_state_dict(main_weights)
#
#        print ("Successfully loaded checkpoint from:", name)
#        
#        self.vae.eval()
#        self.stft.eval()
#        self.model.eval()
#        
#        self.scheduler = DDPMScheduler.from_pretrained(main_config["scheduler_name"], subfolder = "scheduler")

    @staticmethod
    def _load_json(directory, filename):
        """Parse ``<directory>/<filename>`` as UTF-8 JSON, closing the file afterwards."""
        with open("{}/{}".format(directory, filename), encoding = "utf-8") as handle:
            return json.load(handle)

    def chunks(self, lst, n):
        """Yield successive slices of *lst* of length *n* (the last may be shorter)."""
        for i in range(0, len(lst), n):
            yield lst[i:i + n]

    def generate(self, prompt, steps = 100, guidance = 3, samples = 1, disable_progress = True):
        """Generate audio for a single prompt string.

        The prompt is run through ``self.model.inference``, and the resulting
        latents are decoded by the VAE into a mel representation and then into
        waveforms. No gradients are tracked.

        Returns:
            Whatever ``self.vae.decode_to_waveform`` returns for the batch;
            presumably one waveform per requested sample — confirm.
        """
        with torch.no_grad():
            latents = self.model.inference([prompt], self.scheduler, steps, guidance, samples, disable_progress = disable_progress)
            mel = self.vae.decode_first_stage(latents)
            wave = self.vae.decode_to_waveform(mel)
        return wave

    def generate_for_batch(self, prompts, steps = 200, guidance = 3, samples = 1, batch_size = 8, disable_progress = True):
        """Generate audio for a list of prompt strings, *batch_size* prompts at a time.

        Returns:
            When ``samples == 1``, a flat list with one entry per decoded item.
            Otherwise, a list of lists, each holding ``samples`` consecutive
            entries. This is presumably one group per prompt, assuming the model
            emits a prompt's samples contiguously — confirm.
        """
        outputs = []
        for k in tqdm(range(0, len(prompts), batch_size)):
            batch = prompts[k: k + batch_size]
            with torch.no_grad():
                latents = self.model.inference(batch, self.scheduler, steps, guidance, samples, disable_progress = disable_progress)
                mel = self.vae.decode_first_stage(latents)
                wave = self.vae.decode_to_waveform(mel)
                # Accumulate the decoded items one by one into the flat list.
                outputs.extend(wave)
        if samples == 1:
            return outputs
        return list(self.chunks(outputs, samples))

# Initialize TANGO at import time. This calls snapshot_download inside
# Tango.__init__, so the checkpoint is fetched when the module loads.
# NOTE(review): the instance is always built on CPU, whatever `device_type` was
# detected. The lines that would move its sub-modules to `device_type` are
# commented out; `stft` and `model` are not created by Tango.__init__ while its
# loading code is disabled.
tango = Tango(device = "cpu")
#tango.vae.to(device_type)
#tango.stft.to(device_type)
#tango.model.to(device_type)

#def update_seed(is_randomize_seed, seed):
#    if is_randomize_seed:
#        return random.randint(0, max_64_bit_int)
#    return seed
#
#def check(
#    prompt,
#    output_number,
#    steps,
#    guidance,
#    is_randomize_seed,
#    seed
#):
#    if prompt is None or prompt == "":
#        raise gr.Error("Please provide a prompt input.")
#    if not output_number in [1, 2, 3]:
#        raise gr.Error("Please ask for 1, 2 or 3 output files.")
#
#def update_output(output_format, output_number):
#    return [
#        gr.update(format = output_format),
#        gr.update(format = output_format, visible = (2 <= output_number)),
#        gr.update(format = output_format, visible = (output_number == 3)),
#        gr.update(visible = False)
#    ]
#
#def text2audio(
#    prompt,
#    output_number,
#    steps,
#    guidance,
#    is_randomize_seed,
#    seed
#):
#    start = time.time()
#
#    if seed is None:
#        seed = random.randint(0, max_64_bit_int)
#
#    random.seed(seed)
#    torch.manual_seed(seed)
#
#    output_wave = tango.generate(prompt, steps, guidance, output_number)
#
#    output_wave_1 = gr.make_waveform((16000, output_wave[0]))
#    output_wave_2 = gr.make_waveform((16000, output_wave[1])) if (2 <= output_number) else None
#    output_wave_3 = gr.make_waveform((16000, output_wave[2])) if (output_number == 3) else None
#
#    end = time.time()
#    secondes = int(end - start)
#    minutes = secondes // 60
#    secondes = secondes - (minutes * 60)
#    hours = minutes // 60
#    minutes = minutes - (hours * 60)
#    return [
#        output_wave_1,
#        output_wave_2,
#        output_wave_3,
#        gr.update(visible = True, value = "Start again to get a different result. The output have been generated in " + ((str(hours) + " h, ") if hours != 0 else "") + ((str(minutes) + " min, ") if hours != 0 or minutes != 0 else "") + str(secondes) + " sec.")
#    ]
#
#if is_space_imported:
#    text2audio = spaces.GPU(text2audio, duration = 420)

# Old code
net = BriaRMBG()
model_path = hf_hub_download("cocktailpeanut/gbmr", 'model.pth')

# Pick the best available backend: CUDA first, then Apple MPS, then CPU.
# `device` is a module-level global that `process` also reads.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

# On CUDA the checkpoint is loaded without map_location, so tensors go back to
# the devices they were saved on. MPS and CPU remap explicitly.
# NOTE(review): torch.load unpickles the downloaded file, and the checkpoint
# comes from a third-party Hub repo; consider passing weights_only=True.
net.load_state_dict(
    torch.load(model_path) if device == "cuda" else torch.load(model_path, map_location=device)
)
if device != "cpu":
    net = net.to(device)
net.eval()

    
def resize_image(image):
    """Return *image* converted to RGB and resized to the 1024x1024 model input.

    Resampling is bilinear and the aspect ratio is not preserved.
    """
    target_size = (1024, 1024)
    rgb_image = image.convert('RGB')
    return rgb_image.resize(target_size, Image.BILINEAR)


def process(image):
    """Remove the background of *image* using the BRIA RMBG network.

    Args:
        image: an array accepted by ``PIL.Image.fromarray``. This is the value
            of Gradio's ``"image"`` input, presumably HxWx3 uint8 — confirm.

    Returns:
        An RGBA ``PIL.Image`` the size of the input. The original pixels are
        pasted through the predicted mask onto a fully transparent canvas.
    """
    # prepare input: the network runs at a fixed 1024x1024 resolution
    orig_image = Image.fromarray(image)
    w, h = orig_image.size
    im_np = np.array(resize_image(orig_image))
    im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2, 0, 1)  # HWC -> CHW
    im_tensor = torch.unsqueeze(im_tensor, 0)  # add a batch dimension
    im_tensor = torch.divide(im_tensor, 255.0)
    # mean 0.5, std 1.0: shifts [0, 1] pixel values to [-0.5, 0.5]
    im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
    # `device` is the module-level backend chosen when the network was loaded
    im_tensor = im_tensor.to(device)

    # inference: no autograd graph is needed, which saves memory
    with torch.no_grad():
        result = net(im_tensor)

    # post process: upsample result[0][0] back to the original size
    # (presumably the primary mask, shape (1, 1, 1024, 1024) — confirm against BriaRMBG)
    result = torch.squeeze(F.interpolate(result[0][0], size=(h, w), mode='bilinear'), 0)
    ma = torch.max(result)
    mi = torch.min(result)
    if ma > mi:
        result = (result - mi) / (ma - mi)
    else:
        # Constant prediction: min-max scaling would divide by zero and give NaN.
        result = torch.zeros_like(result)

    # image to pil
    im_array = (result * 255).cpu().numpy().astype(np.uint8)
    pil_im = Image.fromarray(np.squeeze(im_array))
    # paste the original image through the mask onto a transparent canvas
    new_im = Image.new("RGBA", pil_im.size, (0, 0, 0, 0))
    new_im.paste(orig_image, mask=pil_im)

    return new_im

# Page header components for the BRIA demo.
# NOTE(review): these are created at module level, outside any `gr.Blocks`
# context, and are not passed to the `gr.Interface` built below. They are
# presumably never rendered — confirm. If they should appear, move the text
# into the Interface's `description`/`article` or into a Blocks layout.
# NOTE(review): the HTML text reads "that using"; "that uses" was likely meant.
gr.Markdown("## BRIA RMBG 1.4")
gr.HTML('''
  <p style="margin-bottom: 10px; font-size: 94%">
    This is a demo for BRIA RMBG 1.4 that using
    <a href="https://huggingface.co/briaai/RMBG-1.4" target="_blank">BRIA RMBG-1.4 image matting model</a> as backbone. 
  </p>
''')
# Text shown by the Gradio Interface.
title = "Background Removal"
description = r"""Background removal model developed by <a href='https://BRIA.AI' target='_blank'><b>BRIA.AI</b></a>, trained on a carefully selected dataset and is available as an open-source model for non-commercial use.<br> 
For test upload your image and wait. Read more at model card <a href='https://huggingface.co/briaai/RMBG-1.4' target='_blank'><b>briaai/RMBG-1.4</b></a>.<br>
"""
# One example row, holding the value for the single `image` input.
examples = [["./input.jpg"]]

# Single-function UI: the uploaded image is passed to `process`, and its result is displayed.
demo = gr.Interface(
    fn=process,
    inputs="image",
    outputs="image",
    examples=examples,
    title=title,
    description=description,
)

if __name__ == "__main__":
    # Serve without creating a public share link.
    demo.launch(share=False)