Spaces:
Running
Running
alessandro trinca tornidor
commited on
Commit
•
de3808d
1
Parent(s):
48f65b6
feat: refactor app.py using samgis-lisa
Browse files
app.py
CHANGED
@@ -1,40 +1,67 @@
|
|
1 |
import json
|
2 |
import os
|
3 |
-
import
|
4 |
from typing import Callable, NoReturn
|
5 |
|
|
|
6 |
import gradio as gr
|
7 |
import spaces
|
|
|
|
|
8 |
import uvicorn
|
|
|
9 |
from fastapi import FastAPI, HTTPException, Request, status
|
10 |
from fastapi.exceptions import RequestValidationError
|
11 |
-
from fastapi.responses import FileResponse, HTMLResponse
|
12 |
from fastapi.staticfiles import StaticFiles
|
13 |
from fastapi.templating import Jinja2Templates
|
14 |
-
from lisa_on_cuda.utils import app_helpers, frontend_builder, create_folders_and_variables_if_not_exists, session_logger
|
15 |
from pydantic import ValidationError
|
|
|
|
|
|
|
|
|
|
|
16 |
|
17 |
-
from samgis_lisa_on_zero import PROJECT_ROOT_FOLDER, WORKDIR, app_logger
|
18 |
-
from samgis_lisa_on_zero.utilities.constants import GRADIO_EXAMPLE_BODY, GRADIO_EXAMPLES_TEXT_LIST, GRADIO_MARKDOWN
|
19 |
-
from samgis_lisa_on_zero.utilities.type_hints import StringPromptApiRequestBody
|
20 |
|
|
|
|
|
|
|
|
|
21 |
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
FASTAPI_TITLE = "samgis-lisa-on-zero"
|
27 |
-
app = FastAPI(title=FASTAPI_TITLE, version="1.0")
|
28 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
|
30 |
@spaces.GPU
|
31 |
-
@session_logger.set_uuid_logging
|
32 |
def gpu_initialization() -> None:
|
33 |
app_logger.info("GPU initialization...")
|
34 |
|
35 |
|
36 |
def get_example_complete(example_text):
|
37 |
-
example_dict = dict(**
|
38 |
example_dict["string_prompt"] = example_text
|
39 |
return json.dumps(example_dict)
|
40 |
|
@@ -64,15 +91,14 @@ def get_gradio_interface_geojson(fn_inference: Callable):
|
|
64 |
return gradio_app
|
65 |
|
66 |
|
67 |
-
@session_logger.set_uuid_logging
|
68 |
def handle_exception_response(exception: Exception) -> NoReturn:
|
69 |
import subprocess
|
70 |
project_root_folder_content = subprocess.run(
|
71 |
-
f"ls -l {
|
72 |
)
|
73 |
app_logger.error(f"project_root folder 'ls -l' command output: {project_root_folder_content.stdout}.")
|
74 |
workdir_folder_content = subprocess.run(
|
75 |
-
f"ls -l {
|
76 |
)
|
77 |
app_logger.error(f"workdir folder 'ls -l' command stdout: {workdir_folder_content.stdout}.")
|
78 |
app_logger.error(f"workdir folder 'ls -l' command stderr: {workdir_folder_content.stderr}.")
|
@@ -83,32 +109,21 @@ def handle_exception_response(exception: Exception) -> NoReturn:
|
|
83 |
|
84 |
|
85 |
@app.get("/health")
|
86 |
-
|
87 |
-
|
88 |
-
import
|
89 |
-
from
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
lisa_on_cuda_version = importlib.metadata.version('lisa-on-cuda')
|
95 |
-
samgis_lisa_on_cuda_version = importlib.metadata.version('samgis-lisa-on-zero')
|
96 |
-
except PackageNotFoundError as pe:
|
97 |
-
app_logger.error(f"pe:{pe}.")
|
98 |
-
|
99 |
-
msg = "still alive, "
|
100 |
-
msg += f"""version:{samgis_lisa_on_cuda_version}, core version:{core_version},"""
|
101 |
-
msg += f"""lisa-on-cuda version:{lisa_on_cuda_version},"""
|
102 |
-
|
103 |
-
app_logger.info(msg)
|
104 |
return JSONResponse(status_code=200, content={"msg": "still alive..."})
|
105 |
|
106 |
|
107 |
-
@session_logger.set_uuid_logging
|
108 |
def infer_lisa_gradio(request_input: StringPromptApiRequestBody) -> str:
|
109 |
-
from
|
110 |
-
from
|
111 |
-
from
|
112 |
|
113 |
app_logger.info("starting lisa inference request...")
|
114 |
|
@@ -140,13 +155,15 @@ def infer_lisa_gradio(request_input: StringPromptApiRequestBody) -> str:
|
|
140 |
app_logger.debug(f"complete json.dumps(body):{dumped}.")
|
141 |
return dumped
|
142 |
except Exception as inference_exception:
|
143 |
-
|
|
|
|
|
144 |
except ValidationError as va1:
|
145 |
app_logger.error(f"validation error: {str(va1)}.")
|
146 |
-
|
|
|
147 |
|
148 |
|
149 |
-
@session_logger.set_uuid_logging
|
150 |
@app.post("/infer_lisa")
|
151 |
def infer_lisa(request_input: StringPromptApiRequestBody) -> JSONResponse:
|
152 |
dumped = infer_lisa_gradio(request_input=request_input)
|
@@ -156,72 +173,46 @@ def infer_lisa(request_input: StringPromptApiRequestBody) -> JSONResponse:
|
|
156 |
|
157 |
|
158 |
@app.exception_handler(RequestValidationError)
|
159 |
-
@session_logger.set_uuid_logging
|
160 |
def request_validation_exception_handler(request: Request, exc: RequestValidationError) -> JSONResponse:
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
app_logger.error(f'request header: {dict(headers)}.')
|
165 |
-
params = request.query_params.items()
|
166 |
-
app_logger.error(f'request query params: {dict(params)}.')
|
167 |
-
return JSONResponse(
|
168 |
-
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
169 |
-
content={"msg": "Error - Unprocessable Entity"}
|
170 |
-
)
|
171 |
|
172 |
|
173 |
@app.exception_handler(HTTPException)
|
174 |
-
@session_logger.set_uuid_logging
|
175 |
def http_exception_handler(request: Request, exc: HTTPException) -> JSONResponse:
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
params = request.query_params.items()
|
180 |
-
app_logger.error(f'request query params: {dict(params)}.')
|
181 |
-
return JSONResponse(
|
182 |
-
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
183 |
-
content={"msg": "Error - Internal Server Error"}
|
184 |
-
)
|
185 |
|
186 |
|
|
|
187 |
write_tmp_on_disk = os.getenv("WRITE_TMP_ON_DISK", "")
|
188 |
app_logger.info(f"write_tmp_on_disk:{write_tmp_on_disk}.")
|
189 |
if bool(write_tmp_on_disk):
|
190 |
try:
|
191 |
-
|
192 |
-
try:
|
193 |
-
pathlib.Path.unlink(path_write_tmp_on_disk, missing_ok=True)
|
194 |
-
except (IsADirectoryError, PermissionError, OSError) as err:
|
195 |
-
app_logger.error(f"{err} while removing old write_tmp_on_disk:{write_tmp_on_disk}.")
|
196 |
-
app_logger.error(f"is file?{path_write_tmp_on_disk.is_file()}.")
|
197 |
-
app_logger.error(f"is symlink?{path_write_tmp_on_disk.is_symlink()}.")
|
198 |
-
app_logger.error(f"is folder?{path_write_tmp_on_disk.is_dir()}.")
|
199 |
-
os.makedirs(write_tmp_on_disk, exist_ok=True)
|
200 |
app.mount("/vis_output", StaticFiles(directory=write_tmp_on_disk), name="vis_output")
|
201 |
-
|
202 |
-
app_logger.error(f"{runtime_error} while loading the folder write_tmp_on_disk:{write_tmp_on_disk}...")
|
203 |
-
raise runtime_error
|
204 |
-
templates = Jinja2Templates(directory=WORKDIR / "static")
|
205 |
-
|
206 |
|
207 |
-
|
208 |
-
|
209 |
|
210 |
-
|
211 |
-
|
212 |
-
|
213 |
-
|
214 |
-
|
215 |
-
|
|
|
|
|
|
|
216 |
|
217 |
-
static_dist_folder = WORKDIR / "static" / "dist"
|
218 |
frontend_builder.build_frontend(
|
219 |
-
project_root_folder=
|
220 |
-
input_css_path=
|
221 |
output_dist_folder=static_dist_folder
|
222 |
)
|
223 |
-
create_folders_and_variables_if_not_exists.folders_creation()
|
224 |
-
|
225 |
app_logger.info("build_frontend ok!")
|
226 |
|
227 |
templates = Jinja2Templates(directory="templates")
|
@@ -229,44 +220,46 @@ templates = Jinja2Templates(directory="templates")
|
|
229 |
app.mount("/static", StaticFiles(directory=static_dist_folder, html=True), name="static")
|
230 |
# important: the index() function and the app.mount MUST be at the end
|
231 |
# samgis.html
|
232 |
-
app.mount(
|
233 |
|
234 |
|
235 |
-
@app.get(
|
236 |
async def samgis() -> FileResponse:
|
237 |
-
return FileResponse(path=static_dist_folder / "samgis.html", media_type="text/html")
|
238 |
|
239 |
|
240 |
# lisa.html
|
241 |
-
app.mount(
|
242 |
|
243 |
|
244 |
-
@app.get(
|
245 |
async def lisa() -> FileResponse:
|
246 |
-
return FileResponse(path=static_dist_folder / "lisa.html", media_type="text/html")
|
247 |
|
248 |
|
249 |
# index.html (lisa.html copy)
|
250 |
-
app.mount(
|
251 |
|
252 |
|
253 |
-
@app.get(
|
254 |
async def index() -> FileResponse:
|
255 |
-
return FileResponse(path=static_dist_folder / "index.html", media_type="text/html")
|
|
|
256 |
|
|
|
|
|
|
|
|
|
|
|
257 |
|
258 |
-
|
259 |
-
|
260 |
-
app_helpers.app_logger.info(
|
261 |
-
f"gradio interface created, mounting gradio app on url {VITE_GRADIO_URL} within FastAPI...")
|
262 |
-
app = gr.mount_gradio_app(app, io, path=VITE_GRADIO_URL)
|
263 |
-
app_helpers.app_logger.info("mounted gradio app within fastapi")
|
264 |
|
265 |
|
266 |
if __name__ == '__main__':
|
267 |
try:
|
268 |
uvicorn.run(host="0.0.0.0", port=7860, app=app)
|
269 |
except Exception as ex:
|
270 |
-
app_logger.error(f"fastapi/gradio application {
|
271 |
-
print(f"fastapi/gradio application {
|
272 |
raise ex
|
|
|
1 |
import json
|
2 |
import os
|
3 |
+
from pathlib import Path
|
4 |
from typing import Callable, NoReturn
|
5 |
|
6 |
+
from asgi_correlation_id import CorrelationIdMiddleware
|
7 |
import gradio as gr
|
8 |
import spaces
|
9 |
+
from starlette.responses import JSONResponse
|
10 |
+
import structlog
|
11 |
import uvicorn
|
12 |
+
from dotenv import load_dotenv
|
13 |
from fastapi import FastAPI, HTTPException, Request, status
|
14 |
from fastapi.exceptions import RequestValidationError
|
15 |
+
from fastapi.responses import FileResponse, HTMLResponse
|
16 |
from fastapi.staticfiles import StaticFiles
|
17 |
from fastapi.templating import Jinja2Templates
|
|
|
18 |
from pydantic import ValidationError
|
19 |
+
from samgis_core.utilities import create_folders_if_not_exists
|
20 |
+
from samgis_core.utilities import frontend_builder
|
21 |
+
from samgis_core.utilities.session_logger import setup_logging
|
22 |
+
from samgis_web.utilities.constants import GRADIO_EXAMPLES_TEXT_LIST, GRADIO_MARKDOWN, GRADIO_EXAMPLE_BODY_STRING_PROMPT
|
23 |
+
from samgis_web.utilities.type_hints import StringPromptApiRequestBody
|
24 |
|
|
|
|
|
|
|
25 |
|
26 |
+
load_dotenv()
|
27 |
+
project_root_folder = Path(globals().get("__file__", "./_")).absolute().parent
|
28 |
+
workdir = Path(os.getenv("WORKDIR", project_root_folder))
|
29 |
+
model_folder = Path(project_root_folder / "machine_learning_models")
|
30 |
|
31 |
+
log_level = os.getenv("LOG_LEVEL", "INFO")
|
32 |
+
setup_logging(log_level=log_level)
|
33 |
+
app_logger = structlog.stdlib.get_logger()
|
34 |
+
app_logger.info(f"PROJECT_ROOT_FOLDER:{project_root_folder}, WORKDIR:{workdir}.")
|
|
|
|
|
35 |
|
36 |
+
folders_map = os.getenv("FOLDERS_MAP", "{}")
|
37 |
+
markdown_text = os.getenv("MARKDOWN_TEXT", "")
|
38 |
+
examples_text_list = os.getenv("EXAMPLES_TEXT_LIST", "").split("\n")
|
39 |
+
example_body = json.loads(os.getenv("EXAMPLE_BODY", "{}"))
|
40 |
+
mount_gradio_app = bool(os.getenv("MOUNT_GRADIO_APP", ""))
|
41 |
+
|
42 |
+
static_dist_folder = workdir / "static" / "dist"
|
43 |
+
input_css_path = os.getenv("INPUT_CSS_PATH", "src/input.css")
|
44 |
+
vite_gradio_url = os.getenv("VITE_GRADIO_URL", "/gradio")
|
45 |
+
vite_index_url = os.getenv("VITE_INDEX_URL", "/")
|
46 |
+
vite_samgis_url = os.getenv("VITE_SAMGIS_URL", "/samgis")
|
47 |
+
vite_lisa_url = os.getenv("VITE_LISA_URL", "/lisa")
|
48 |
+
fastapi_title = "samgis-lisa-on-zero2"
|
49 |
+
app = FastAPI(title=fastapi_title, version="1.0")
|
50 |
+
|
51 |
+
|
52 |
+
@app.middleware("http")
|
53 |
+
async def request_middleware(request, call_next):
|
54 |
+
from samgis_web.web.middlewares import logging_middleware
|
55 |
+
|
56 |
+
return await logging_middleware(request, call_next)
|
57 |
|
58 |
@spaces.GPU
|
|
|
59 |
def gpu_initialization() -> None:
|
60 |
app_logger.info("GPU initialization...")
|
61 |
|
62 |
|
63 |
def get_example_complete(example_text):
|
64 |
+
example_dict = dict(**GRADIO_EXAMPLE_BODY_STRING_PROMPT)
|
65 |
example_dict["string_prompt"] = example_text
|
66 |
return json.dumps(example_dict)
|
67 |
|
|
|
91 |
return gradio_app
|
92 |
|
93 |
|
|
|
94 |
def handle_exception_response(exception: Exception) -> NoReturn:
|
95 |
import subprocess
|
96 |
project_root_folder_content = subprocess.run(
|
97 |
+
f"ls -l {project_root_folder}/", shell=True, universal_newlines=True, stdout=subprocess.PIPE
|
98 |
)
|
99 |
app_logger.error(f"project_root folder 'ls -l' command output: {project_root_folder_content.stdout}.")
|
100 |
workdir_folder_content = subprocess.run(
|
101 |
+
f"ls -l {workdir}/", shell=True, universal_newlines=True, stdout=subprocess.PIPE
|
102 |
)
|
103 |
app_logger.error(f"workdir folder 'ls -l' command stdout: {workdir_folder_content.stdout}.")
|
104 |
app_logger.error(f"workdir folder 'ls -l' command stderr: {workdir_folder_content.stderr}.")
|
|
|
109 |
|
110 |
|
111 |
@app.get("/health")
|
112 |
+
async def health() -> JSONResponse:
|
113 |
+
from samgis_web.__version__ import __version__ as version_web
|
114 |
+
from samgis_core.__version__ import __version__ as version_core
|
115 |
+
from lisa_on_cuda.__version__ import __version__ as version_lisa_on_cuda
|
116 |
+
from samgis_lisa.__version__ import __version__ as version_samgis_lisa
|
117 |
+
|
118 |
+
app_logger.info(f"still alive, version_web:{version_web}, version_core:{version_core}.")
|
119 |
+
app_logger.info(f"still alive, version_lisa_on_cuda:{version_lisa_on_cuda}, version_samgis_lisa:{version_samgis_lisa}.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
120 |
return JSONResponse(status_code=200, content={"msg": "still alive..."})
|
121 |
|
122 |
|
|
|
123 |
def infer_lisa_gradio(request_input: StringPromptApiRequestBody) -> str:
|
124 |
+
from samgis_lisa.io_package.wrappers_helpers import get_parsed_bbox_points_with_string_prompt
|
125 |
+
from samgis_lisa.prediction_api import lisa
|
126 |
+
from samgis_lisa.utilities.constants import LISA_INFERENCE_FN
|
127 |
|
128 |
app_logger.info("starting lisa inference request...")
|
129 |
|
|
|
155 |
app_logger.debug(f"complete json.dumps(body):{dumped}.")
|
156 |
return dumped
|
157 |
except Exception as inference_exception:
|
158 |
+
app_logger.error(f"inference_exception:{inference_exception}.")
|
159 |
+
app_logger.error(f"inference_exception, request_input:{request_input}.")
|
160 |
+
raise HTTPException(status_code=500, detail="Internal Server Error")
|
161 |
except ValidationError as va1:
|
162 |
app_logger.error(f"validation error: {str(va1)}.")
|
163 |
+
app_logger.error(f"ValidationError, request_input:{request_input}.")
|
164 |
+
raise RequestValidationError("Unprocessable Entity")
|
165 |
|
166 |
|
|
|
167 |
@app.post("/infer_lisa")
|
168 |
def infer_lisa(request_input: StringPromptApiRequestBody) -> JSONResponse:
|
169 |
dumped = infer_lisa_gradio(request_input=request_input)
|
|
|
173 |
|
174 |
|
175 |
@app.exception_handler(RequestValidationError)
|
|
|
176 |
def request_validation_exception_handler(request: Request, exc: RequestValidationError) -> JSONResponse:
|
177 |
+
from samgis_web.web import exception_handlers
|
178 |
+
|
179 |
+
return exception_handlers.request_validation_exception_handler(request, exc)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
180 |
|
181 |
|
182 |
@app.exception_handler(HTTPException)
|
|
|
183 |
def http_exception_handler(request: Request, exc: HTTPException) -> JSONResponse:
|
184 |
+
from samgis_web.web import exception_handlers
|
185 |
+
|
186 |
+
return exception_handlers.http_exception_handler(request, exc)
|
|
|
|
|
|
|
|
|
|
|
|
|
187 |
|
188 |
|
189 |
+
create_folders_if_not_exists.folders_creation(folders_map)
|
190 |
write_tmp_on_disk = os.getenv("WRITE_TMP_ON_DISK", "")
|
191 |
app_logger.info(f"write_tmp_on_disk:{write_tmp_on_disk}.")
|
192 |
if bool(write_tmp_on_disk):
|
193 |
try:
|
194 |
+
assert Path(write_tmp_on_disk).is_dir()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
195 |
app.mount("/vis_output", StaticFiles(directory=write_tmp_on_disk), name="vis_output")
|
196 |
+
templates = Jinja2Templates(directory=str(project_root_folder / "static"))
|
|
|
|
|
|
|
|
|
197 |
|
198 |
+
@app.get("/vis_output", response_class=HTMLResponse)
|
199 |
+
def list_files(request: Request):
|
200 |
|
201 |
+
files = os.listdir(write_tmp_on_disk)
|
202 |
+
files_paths = sorted([f"{request.url._url}/{f}" for f in files])
|
203 |
+
print(files_paths)
|
204 |
+
return templates.TemplateResponse(
|
205 |
+
"list_files.html", {"request": request, "files": files_paths}
|
206 |
+
)
|
207 |
+
except (AssertionError, RuntimeError) as rerr:
|
208 |
+
app_logger.error(f"{rerr} while loading the folder write_tmp_on_disk:{write_tmp_on_disk}...")
|
209 |
+
raise rerr
|
210 |
|
|
|
211 |
frontend_builder.build_frontend(
|
212 |
+
project_root_folder=workdir,
|
213 |
+
input_css_path=input_css_path,
|
214 |
output_dist_folder=static_dist_folder
|
215 |
)
|
|
|
|
|
216 |
app_logger.info("build_frontend ok!")
|
217 |
|
218 |
templates = Jinja2Templates(directory="templates")
|
|
|
220 |
app.mount("/static", StaticFiles(directory=static_dist_folder, html=True), name="static")
|
221 |
# important: the index() function and the app.mount MUST be at the end
|
222 |
# samgis.html
|
223 |
+
app.mount(vite_samgis_url, StaticFiles(directory=static_dist_folder, html=True), name="samgis")
|
224 |
|
225 |
|
226 |
+
@app.get(vite_samgis_url)
|
227 |
async def samgis() -> FileResponse:
|
228 |
+
return FileResponse(path=str(static_dist_folder / "samgis.html"), media_type="text/html")
|
229 |
|
230 |
|
231 |
# lisa.html
|
232 |
+
app.mount(vite_lisa_url, StaticFiles(directory=static_dist_folder, html=True), name="lisa")
|
233 |
|
234 |
|
235 |
+
@app.get(vite_lisa_url)
|
236 |
async def lisa() -> FileResponse:
|
237 |
+
return FileResponse(path=str(static_dist_folder / "lisa.html"), media_type="text/html")
|
238 |
|
239 |
|
240 |
# index.html (lisa.html copy)
|
241 |
+
app.mount(vite_index_url, StaticFiles(directory=static_dist_folder, html=True), name="index")
|
242 |
|
243 |
|
244 |
+
@app.get(vite_index_url)
|
245 |
async def index() -> FileResponse:
|
246 |
+
return FileResponse(path=str(static_dist_folder / "index.html"), media_type="text/html")
|
247 |
+
|
248 |
|
249 |
+
app_logger.info(f"creating gradio interface...")
|
250 |
+
gr_interface = get_gradio_interface_geojson(infer_lisa_gradio)
|
251 |
+
app_logger.info(f"gradio interface created, mounting gradio app on url {vite_gradio_url} within FastAPI...")
|
252 |
+
app = gr.mount_gradio_app(app, gr_interface, path=vite_gradio_url)
|
253 |
+
app_logger.info("mounted gradio app within fastapi")
|
254 |
|
255 |
+
# add the CorrelationIdMiddleware AFTER the @app.middleware("http") decorated function to avoid missing request id
|
256 |
+
app.add_middleware(CorrelationIdMiddleware)
|
|
|
|
|
|
|
|
|
257 |
|
258 |
|
259 |
if __name__ == '__main__':
|
260 |
try:
|
261 |
uvicorn.run(host="0.0.0.0", port=7860, app=app)
|
262 |
except Exception as ex:
|
263 |
+
app_logger.error(f"fastapi/gradio application {fastapi_title}, exception:{ex}.")
|
264 |
+
print(f"fastapi/gradio application {fastapi_title}, exception:{ex}.")
|
265 |
raise ex
|