File size: 7,086 Bytes
2e40cec
68e098f
eea7d12
ee3e751
 
2e40cec
bfe0a04
2e40cec
798f2d5
 
 
 
25a9813
798f2d5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4ddd13d
798f2d5
 
 
 
7b36ba0
798f2d5
 
 
 
4ddd13d
7b36ba0
798f2d5
 
 
 
 
 
 
4ddd13d
798f2d5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4ddd13d
 
4fb9833
e453372
 
 
 
4ddd13d
 
4fb9833
68e098f
4fb9833
 
aeb7cb4
4fb9833
 
7815fb5
937ea93
eea7d12
 
 
 
 
 
 
a65d535
4fb9833
a65d535
 
4fb9833
a65d535
 
69b2ced
dd41957
 
3ed4f47
4fb9833
69b2ced
 
4fb9833
25a9813
 
4fb9833
25a9813
798f2d5
2069ee0
68e098f
 
 
 
35a15c6
18a36b2
 
 
e453372
9cda3f2
0607a3e
18a36b2
 
ff1294b
0607a3e
9cda3f2
6696be0
68e098f
 
35a15c6
b110321
 
 
 
b2f67f7
b110321
 
 
9cda3f2
2ec47eb
68e098f
 
b110321
 
 
9acd6a4
b110321
 
02a57d7
b110321
02a57d7
2ec47eb
f8ac765
798f2d5
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import modal
import gradio as gr 
import numpy as np 
from io import BytesIO
import requests

# Handle to the Modal class serving the design model; the tab handlers below
# invoke it through ``f.inference.remote(...)``.
_MODAL_APP_NAME = "casa-interior-hf-v3"
_MODAL_CLS_NAME = "DesignModel"
f = modal.Cls.lookup(_MODAL_APP_NAME, _MODAL_CLS_NAME)

import requests
from io import BytesIO
from google.cloud import vision
from google.oauth2 import service_account
import PIL 

def _load_service_account_info() -> dict:
    """Return the Google Cloud service-account key as a dict.

    The key is read from the environment so that no secret lives in source:

    * ``GOOGLE_SERVICE_ACCOUNT_JSON`` -- the key file's JSON contents, or
    * ``GOOGLE_APPLICATION_CREDENTIALS`` -- a filesystem path to the key file.

    Raises:
        RuntimeError: if neither environment variable is set.
    """
    import json
    import os

    raw = os.environ.get("GOOGLE_SERVICE_ACCOUNT_JSON")
    if raw:
        return json.loads(raw)

    path = os.environ.get("GOOGLE_APPLICATION_CREDENTIALS")
    if path:
        with open(path, encoding="utf-8") as fh:
            return json.load(fh)

    raise RuntimeError(
        "No Google Cloud service-account credentials configured: set "
        "GOOGLE_SERVICE_ACCOUNT_JSON (key JSON contents) or "
        "GOOGLE_APPLICATION_CREDENTIALS (path to the key file)."
    )


# SECURITY: never commit service-account private keys to source. Any key that
# has ever been committed must be treated as compromised and revoked/rotated in
# the Google Cloud console.
credentials = _load_service_account_info()

class GetProduct:
    """Find visually similar products on the web with Google Cloud Vision.

    Wraps a Vision ``ImageAnnotatorClient`` authenticated with the
    module-level service-account ``credentials`` dict, runs web detection on
    an image and downloads the "visually similar images" Vision reports.
    """

    # Seconds to wait for each similar-image download. Without a timeout a
    # single unresponsive host could block the request handler indefinitely.
    REQUEST_TIMEOUT = 10

    def __init__(self):
        creds = service_account.Credentials.from_service_account_info(credentials)
        self.client = vision.ImageAnnotatorClient(credentials=creds)

    def inference(self, cropped_image) -> list:
        """Return images visually similar to ``cropped_image``.

        Args:
            cropped_image: PIL image of the object to search for.

        Returns:
            A (possibly empty) list of PIL images downloaded from the URLs
            that Vision web detection lists as visually similar.
        """
        annotations = self.annotate_image(cropped_image)
        return self.report(annotations)

    def annotate_image(self, image):
        """Run Vision web detection on ``image`` and return the annotation.

        The image is re-encoded as JPEG before upload. JPEG cannot store an
        alpha channel or a palette, so anything other than L/RGB is converted
        to RGB first (the save would otherwise raise for RGBA, LA, P, ...).

        Raises:
            RuntimeError: if the Vision API reports an error for the request.
        """
        if image.mode not in ("L", "RGB"):
            image = image.convert("RGB")

        buffer = BytesIO()
        image.save(buffer, format="JPEG")

        response = self.client.web_detection(image=vision.Image(content=buffer.getvalue()))
        # Per-request failures are reported in ``response.error`` rather than
        # raised, so check it explicitly instead of returning empty results.
        if response.error.message:
            raise RuntimeError(f"Vision web detection failed: {response.error.message}")
        return response.web_detection

    def report(self, annotations) -> list:
        """Download the visually similar images listed in ``annotations``.

        Downloading is best effort: URLs that cannot be fetched or decoded as
        an image are skipped (and logged), so the result may be shorter than
        the annotation list.

        Args:
            annotations: web-detection result exposing
                ``visually_similar_images`` (items with a ``url``).

        Returns:
            List of PIL images, in annotation order.
        """
        pages = annotations.visually_similar_images
        if not pages:
            return []

        # Imported here because the module only does ``import PIL``, which
        # does not guarantee that the ``PIL.Image`` submodule is loaded.
        from PIL import Image

        selected_images = []
        for page in pages:
            try:
                response = requests.get(page.url, timeout=self.REQUEST_TIMEOUT)
                response.raise_for_status()
                img = Image.open(BytesIO(response.content))
                # ``Image.open`` is lazy; decode now so corrupt data fails
                # inside this ``try`` rather than later in the gallery.
                img.load()
            except (requests.RequestException, OSError, ValueError,
                    Image.DecompressionBombError) as err:
                print(f"Skipping similar image {page.url}: {err}")
                continue
            selected_images.append(img)
        return selected_images

# Shared GetProduct instance used by casa_ai_run_tab3 (the "Recommendation"
# tab). It is built once at import time; construction creates the Google
# service-account credentials and the Vision client.
GP = GetProduct()

def casa_ai_run_tab1(image=None, text=None):
    """Handle the "Reimagine" tab: redesign a room image from a text prompt.

    Args:
        image: uploaded room image (the tab's ``gr.Image`` uses ``type="pil"``).
        text: room description entered in the textbox.

    Returns:
        Whatever the remote model returns for ``"tab1"`` (shown in a
        ``gr.Image``), or ``None`` when an input is missing.
    """
    if image is None:
        print('Please provide image of empty room to design')
        return None

    # A blank gr.Textbox submits an empty string, not None, so reject
    # empty/whitespace-only prompts as well.
    if text is None or not text.strip():
        print('Please provide a text prompt')
        return None

    return f.inference.remote("tab1", image, text)

def casa_ai_run_tab2(dict=None, text=None):
    """Handle the "Redesign" tab: regenerate the masked region of a room image.

    Args:
        dict: ``gr.ImageEditor`` value (``type="pil"``) with a ``"background"``
            image and a ``"layers"`` list whose first entry holds the user's
            mask strokes. The name shadows the builtin but is kept so existing
            callers are unaffected.
        text: description of the object to generate in the masked region.

    Returns:
        Whatever the remote model returns for ``"tab2"``, or ``None`` when no
        image was uploaded.
    """
    editor_value = dict or {}

    # Check for a missing upload before touching the images; otherwise a
    # None value/background or an empty layer list would crash here.
    background = editor_value.get("background")
    if background is None:
        print('Please provide context in form of image, text')
        return None
    image = background.convert("RGB")

    layers = editor_value.get("layers") or []
    mask = layers[0].convert('L') if layers else None
    # An untouched layer is all zeros: treat it the same as "no mask".
    if mask is not None and not np.any(np.asarray(mask)):
        mask = None

    if mask is None:
        # NOTE(review): the request still proceeds with mask=None; confirm
        # the remote model supports that, or return None here instead.
        print('Please provide a mask over the object you want to generate again.')

    return f.inference.remote("tab2", image, text, mask)

def casa_ai_run_tab3(dict=None):
    """Handle the "Recommendation" tab: find products similar to a crop.

    Args:
        dict: ``gr.ImageEditor`` value (``type="numpy"``) with keys
            ``'background'``, ``'layers'`` and ``'composite'``. Only the
            ``'composite'`` array (the cropped image) is used. The name
            shadows the builtin but is kept so existing callers are unaffected.

    Returns:
        The list of PIL images from ``GP.inference`` for the gallery, or
        ``None`` when no image was provided.
    """
    selected_crop = (dict or {}).get("composite")
    if selected_crop is None:
        print('Please provide cropped object')
        return None

    # Imported here because the module only does ``import PIL``, which does
    # not guarantee that the ``PIL.Image`` submodule is loaded.
    from PIL import Image

    return GP.inference(Image.fromarray(selected_crop))

# Page header text; ``title`` also becomes the browser-tab title.
title = "Casa-AI Demo"
description = "A Gradio interface to use CasaAI for virtual staging"

with gr.Blocks(title=title) as casa:
    gr.Markdown(f"# {title}\n\n{description}")

    # "Reimagine": room image + text prompt -> generated room (casa_ai_run_tab1).
    with gr.Tab("Reimagine"):
        with gr.Row():
            with gr.Column():
                inputs = [
                    gr.Image(sources=["upload"], type="pil", label="Upload"),
                    gr.Textbox(label="Room description."),
                ]
            with gr.Column():
                outputs = [gr.Image(label="Generated room image")]

        submit_btn = gr.Button("Generate!")
        submit_btn.click(casa_ai_run_tab1, inputs=inputs, outputs=outputs)

    # "Redesign": paint a white mask over an object and describe its
    # replacement (casa_ai_run_tab2).
    with gr.Tab("Redesign"):
        with gr.Row():
            with gr.Column():
                inputs = [
                    gr.ImageEditor(
                        sources=["upload"],
                        brush=gr.Brush(colors=["#FFFFFF"]),
                        # elem_id becomes an HTML id, so it must be unique on the page.
                        elem_id="redesign_image_upload",
                        type="pil",
                        label="Upload",
                        layers=False,
                        eraser=True,
                        transforms=[],
                    ),
                    gr.Textbox(label="Description for redesigning masked object"),
                ]
            with gr.Column():
                outputs = [gr.Image(label="Image with new designed object")]

        submit_btn = gr.Button("Redesign!")
        submit_btn.click(casa_ai_run_tab2, inputs=inputs, outputs=outputs)

    # "Recommendation": crop an object (1:1) and search the web for similar
    # products (casa_ai_run_tab3).
    with gr.Tab("Recommendation"):
        with gr.Row():
            with gr.Column():
                inputs = [
                    gr.ImageEditor(
                        sources=["upload"],
                        elem_id="recommendation_image_upload",
                        type="numpy",
                        label="Upload",
                        layers=False,
                        eraser=False,
                        brush=False,
                        transforms=["crop"],
                        crop_size="1:1",
                    ),
                ]
            with gr.Column():
                outputs = [gr.Gallery(label="Similar products")]

        submit_btn = gr.Button("Find similar products!")
        submit_btn.click(casa_ai_run_tab3, inputs=inputs, outputs=outputs)

casa.launch()