Spaces:
Running
on
Zero
Running
on
Zero
DongfuJiang
commited on
Commit
•
350d553
1
Parent(s):
625938c
update
Browse files- model/model_manager.py +1 -9
- model/models/imagenhub_models.py +22 -0
- serve/constants.py +1 -0
- serve/log_server.py +19 -2
- serve/log_utils.py +30 -2
- serve/utils.py +9 -1
model/model_manager.py
CHANGED
@@ -3,7 +3,6 @@ import random
|
|
3 |
import gradio as gr
|
4 |
import requests
|
5 |
import io, base64, json
|
6 |
-
import spaces
|
7 |
from PIL import Image
|
8 |
from .models import IMAGE_GENERATION_MODELS, IMAGE_EDITION_MODELS, load_pipeline
|
9 |
|
@@ -21,7 +20,6 @@ class ModelManager:
|
|
21 |
pipe = self.loaded_models[model_name]
|
22 |
return pipe
|
23 |
|
24 |
-
@spaces.GPU(duration=60)
|
25 |
def generate_image_ig(self, prompt, model_name):
|
26 |
pipe = self.load_model_pipe(model_name)
|
27 |
result = pipe(prompt=prompt)
|
@@ -51,15 +49,9 @@ class ModelManager:
|
|
51 |
results.append(result)
|
52 |
return results[0], results[1]
|
53 |
|
54 |
-
@spaces.GPU(duration=150)
|
55 |
def generate_image_ie(self, textbox_source, textbox_target, textbox_instruct, source_image, model_name):
|
56 |
pipe = self.load_model_pipe(model_name)
|
57 |
-
|
58 |
-
result = pipe(src_image = source_image, src_prompt = textbox_source, target_prompt = textbox_target, instruct_prompt = textbox_instruct, num_inversion_steps=100)
|
59 |
-
elif 'Prompt2prompt' in model_name:
|
60 |
-
result = pipe(src_image = source_image, src_prompt = textbox_source, target_prompt = textbox_target, instruct_prompt = textbox_instruct, num_inner_steps=5)
|
61 |
-
else:
|
62 |
-
result = pipe(src_image = source_image, src_prompt = textbox_source, target_prompt = textbox_target, instruct_prompt = textbox_instruct)
|
63 |
return result
|
64 |
|
65 |
def generate_image_ie_parallel(self, textbox_source, textbox_target, textbox_instruct, source_image, model_A, model_B):
|
|
|
3 |
import gradio as gr
|
4 |
import requests
|
5 |
import io, base64, json
|
|
|
6 |
from PIL import Image
|
7 |
from .models import IMAGE_GENERATION_MODELS, IMAGE_EDITION_MODELS, load_pipeline
|
8 |
|
|
|
20 |
pipe = self.loaded_models[model_name]
|
21 |
return pipe
|
22 |
|
|
|
23 |
def generate_image_ig(self, prompt, model_name):
|
24 |
pipe = self.load_model_pipe(model_name)
|
25 |
result = pipe(prompt=prompt)
|
|
|
49 |
results.append(result)
|
50 |
return results[0], results[1]
|
51 |
|
|
|
52 |
def generate_image_ie(self, textbox_source, textbox_target, textbox_instruct, source_image, model_name):
|
53 |
pipe = self.load_model_pipe(model_name)
|
54 |
+
result = pipe(src_image = source_image, src_prompt = textbox_source, target_prompt = textbox_target, instruct_prompt = textbox_instruct)
|
|
|
|
|
|
|
|
|
|
|
55 |
return result
|
56 |
|
57 |
def generate_image_ie_parallel(self, textbox_source, textbox_target, textbox_instruct, source_image, model_A, model_B):
|
model/models/imagenhub_models.py
CHANGED
@@ -7,5 +7,27 @@ class ImagenHubModel():
|
|
7 |
def __call__(self, *args, **kwargs):
|
8 |
return self.model.infer_one_image(*args, **kwargs)
|
9 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
def load_imagenhub_model(model_name, model_type=None):
|
|
|
|
|
|
|
|
|
11 |
return ImagenHubModel(model_name)
|
|
|
7 |
def __call__(self, *args, **kwargs):
|
8 |
return self.model.infer_one_image(*args, **kwargs)
|
9 |
|
10 |
+
class PNP(ImagenHubModel):
|
11 |
+
def __init__(self):
|
12 |
+
super().__init__('PNP')
|
13 |
+
|
14 |
+
def __call__(self, *args, **kwargs):
|
15 |
+
if "num_inversion_steps" not in kwargs:
|
16 |
+
kwargs["num_inversion_steps"] = 100
|
17 |
+
return super().__call__(*args, **kwargs)
|
18 |
+
|
19 |
+
class Prompt2prompt(ImagenHubModel):
|
20 |
+
def __init__(self):
|
21 |
+
super().__init__('Prompt2prompt')
|
22 |
+
|
23 |
+
def __call__(self, *args, **kwargs):
|
24 |
+
if "num_inner_steps" not in kwargs:
|
25 |
+
kwargs["num_inner_steps"] = 5
|
26 |
+
return super().__call__(*args, **kwargs)
|
27 |
+
|
28 |
def load_imagenhub_model(model_name, model_type=None):
|
29 |
+
if model_name == 'PNP':
|
30 |
+
return PNP()
|
31 |
+
if model_name == 'Prompt2prompt':
|
32 |
+
return Prompt2prompt()
|
33 |
return ImagenHubModel(model_name)
|
serve/constants.py
CHANGED
@@ -13,4 +13,5 @@ LOG_SERVER_ADDR = os.getenv("LOG_SERVER_ADDR", f"{LOG_SERVER}/{LOG_SERVER_SUBDOA
|
|
13 |
# LOG SERVER API ENDPOINTS
|
14 |
APPEND_JSON = "append_json"
|
15 |
SAVE_IMAGE = "save_image"
|
|
|
16 |
|
|
|
13 |
# LOG SERVER API ENDPOINTS
|
14 |
APPEND_JSON = "append_json"
|
15 |
SAVE_IMAGE = "save_image"
|
16 |
+
SAVE_LOG = "save_log"
|
17 |
|
serve/log_server.py
CHANGED
@@ -4,9 +4,9 @@ import json
|
|
4 |
import os
|
5 |
import aiofiles
|
6 |
from .log_utils import build_logger
|
7 |
-
from .constants import LOG_SERVER_SUBDOAMIN, APPEND_JSON, SAVE_IMAGE
|
8 |
|
9 |
-
logger = build_logger("log_server", "log_server.log")
|
10 |
|
11 |
app = APIRouter(prefix=f"/{LOG_SERVER_SUBDOAMIN}")
|
12 |
|
@@ -37,3 +37,20 @@ async def save_image(image: UploadFile = File(...), image_path: str = Form(...))
|
|
37 |
await f.write(content) # Write the image content to a file
|
38 |
logger.info(f"Image saved successfully at {image_path}")
|
39 |
return {"message": f"Image saved successfully at {image_path}"}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
import os
|
5 |
import aiofiles
|
6 |
from .log_utils import build_logger
|
7 |
+
from .constants import LOG_SERVER_SUBDOAMIN, APPEND_JSON, SAVE_IMAGE, SAVE_LOG
|
8 |
|
9 |
+
logger = build_logger("log_server", "log_server.log", add_remote_handler=False)
|
10 |
|
11 |
app = APIRouter(prefix=f"/{LOG_SERVER_SUBDOAMIN}")
|
12 |
|
|
|
37 |
await f.write(content) # Write the image content to a file
|
38 |
logger.info(f"Image saved successfully at {image_path}")
|
39 |
return {"message": f"Image saved successfully at {image_path}"}
|
40 |
+
|
41 |
+
@app.post(f"/{SAVE_LOG}")
|
42 |
+
async def save_log(message: str = Form(...), log_path: str = Form(...)):
|
43 |
+
"""
|
44 |
+
Save a log message to a specified log file on the server.
|
45 |
+
"""
|
46 |
+
print(f"Received log message: {message} to be saved at: {log_path}")
|
47 |
+
# Ensure the directory for the log file exists
|
48 |
+
if os.path.dirname(log_path):
|
49 |
+
os.makedirs(os.path.dirname(log_path), exist_ok=True)
|
50 |
+
|
51 |
+
# Append the log message to the specified log file
|
52 |
+
async with aiofiles.open(log_path, mode='a') as f:
|
53 |
+
await f.write(f"{message}\n")
|
54 |
+
|
55 |
+
logger.info(f"Romote log message saved to {log_path}")
|
56 |
+
return {"message": f"Log message saved successfully to {log_path}"}
|
serve/log_utils.py
CHANGED
@@ -10,17 +10,36 @@ import platform
|
|
10 |
import sys
|
11 |
from typing import AsyncGenerator, Generator
|
12 |
import warnings
|
|
|
13 |
|
14 |
import requests
|
15 |
|
16 |
-
from .constants import LOGDIR
|
|
|
17 |
|
18 |
|
19 |
handler = None
|
20 |
visited_loggers = set()
|
21 |
|
22 |
|
23 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
global handler
|
25 |
|
26 |
formatter = logging.Formatter(
|
@@ -56,6 +75,15 @@ def build_logger(logger_name, logger_filename):
|
|
56 |
# Get logger
|
57 |
logger = logging.getLogger(logger_name)
|
58 |
logger.setLevel(logging.INFO)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
59 |
|
60 |
# if LOGDIR is empty, then don't try output log to local file
|
61 |
if LOGDIR != "":
|
|
|
10 |
import sys
|
11 |
from typing import AsyncGenerator, Generator
|
12 |
import warnings
|
13 |
+
from pathlib import Path
|
14 |
|
15 |
import requests
|
16 |
|
17 |
+
from .constants import LOGDIR, LOG_SERVER_ADDR, SAVE_LOG
|
18 |
+
from .utils import save_log_str_on_log_server
|
19 |
|
20 |
|
21 |
handler = None
|
22 |
visited_loggers = set()
|
23 |
|
24 |
|
25 |
+
# Assuming LOGDIR and other necessary imports and global variables are defined
|
26 |
+
|
27 |
+
class APIHandler(logging.Handler):
|
28 |
+
"""Custom logging handler that sends logs to an API."""
|
29 |
+
|
30 |
+
def __init__(self, apiUrl, log_path, *args, **kwargs):
|
31 |
+
super(APIHandler, self).__init__(*args, **kwargs)
|
32 |
+
self.apiUrl = apiUrl
|
33 |
+
self.log_path = log_path
|
34 |
+
|
35 |
+
def emit(self, record):
|
36 |
+
log_entry = self.format(record)
|
37 |
+
try:
|
38 |
+
save_log_str_on_log_server(log_entry, self.log_path)
|
39 |
+
except requests.RequestException as e:
|
40 |
+
print(f"Error sending log to API: {e}", file=sys.stderr)
|
41 |
+
|
42 |
+
def build_logger(logger_name, logger_filename, add_remote_handler=True):
|
43 |
global handler
|
44 |
|
45 |
formatter = logging.Formatter(
|
|
|
75 |
# Get logger
|
76 |
logger = logging.getLogger(logger_name)
|
77 |
logger.setLevel(logging.INFO)
|
78 |
+
|
79 |
+
if add_remote_handler:
|
80 |
+
# Add APIHandler to send logs to your API
|
81 |
+
api_url = f"{LOG_SERVER_ADDR}/{SAVE_LOG}"
|
82 |
+
|
83 |
+
remote_logger_filename = str(Path(logger_filename).stem + "_remote.log")
|
84 |
+
api_handler = APIHandler(apiUrl=api_url, log_path=f"{LOGDIR}/{remote_logger_filename}")
|
85 |
+
api_handler.setFormatter(formatter)
|
86 |
+
logger.addHandler(api_handler)
|
87 |
|
88 |
# if LOGDIR is empty, then don't try output log to local file
|
89 |
if LOGDIR != "":
|
serve/utils.py
CHANGED
@@ -6,7 +6,7 @@ import numpy as np
|
|
6 |
import gradio as gr
|
7 |
from pathlib import Path
|
8 |
from model.model_registry import *
|
9 |
-
from .constants import LOGDIR, LOG_SERVER_ADDR, APPEND_JSON, SAVE_IMAGE
|
10 |
from typing import Union
|
11 |
|
12 |
|
@@ -159,4 +159,12 @@ def append_json_item_on_log_server(json_item: Union[dict, str], log_file: str):
|
|
159 |
url = f"{LOG_SERVER_ADDR}/{APPEND_JSON}"
|
160 |
# Make the POST request, sending the JSON string and the log file name
|
161 |
response = requests.post(url, data={'json_str': json_item, 'file_name': log_file})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
162 |
return response
|
|
|
6 |
import gradio as gr
|
7 |
from pathlib import Path
|
8 |
from model.model_registry import *
|
9 |
+
from .constants import LOGDIR, LOG_SERVER_ADDR, APPEND_JSON, SAVE_IMAGE, SAVE_LOG
|
10 |
from typing import Union
|
11 |
|
12 |
|
|
|
159 |
url = f"{LOG_SERVER_ADDR}/{APPEND_JSON}"
|
160 |
# Make the POST request, sending the JSON string and the log file name
|
161 |
response = requests.post(url, data={'json_str': json_item, 'file_name': log_file})
|
162 |
+
return response
|
163 |
+
|
164 |
+
def save_log_str_on_log_server(log_str: str, log_file: str):
|
165 |
+
log_file = Path(log_file).absolute().relative_to(os.getcwd())
|
166 |
+
log_file = str(log_file)
|
167 |
+
url = f"{LOG_SERVER_ADDR}/{SAVE_LOG}"
|
168 |
+
# Make the POST request, sending the log message and the log file name
|
169 |
+
response = requests.post(url, data={'message': log_str, 'log_path': log_file})
|
170 |
return response
|