alessandro trinca tornidor commited on
Commit
de3808d
1 Parent(s): 48f65b6

feat: refactor app.py using samgis-lisa

Browse files
Files changed (1) hide show
  1. app.py +100 -107
app.py CHANGED
@@ -1,40 +1,67 @@
1
  import json
2
  import os
3
- import pathlib
4
  from typing import Callable, NoReturn
5
 
 
6
  import gradio as gr
7
  import spaces
 
 
8
  import uvicorn
 
9
  from fastapi import FastAPI, HTTPException, Request, status
10
  from fastapi.exceptions import RequestValidationError
11
- from fastapi.responses import FileResponse, HTMLResponse, JSONResponse
12
  from fastapi.staticfiles import StaticFiles
13
  from fastapi.templating import Jinja2Templates
14
- from lisa_on_cuda.utils import app_helpers, frontend_builder, create_folders_and_variables_if_not_exists, session_logger
15
  from pydantic import ValidationError
 
 
 
 
 
16
 
17
- from samgis_lisa_on_zero import PROJECT_ROOT_FOLDER, WORKDIR, app_logger
18
- from samgis_lisa_on_zero.utilities.constants import GRADIO_EXAMPLE_BODY, GRADIO_EXAMPLES_TEXT_LIST, GRADIO_MARKDOWN
19
- from samgis_lisa_on_zero.utilities.type_hints import StringPromptApiRequestBody
20
 
 
 
 
 
21
 
22
- VITE_INDEX_URL = os.getenv("VITE_INDEX_URL", "/")
23
- VITE_SAMGIS_URL = os.getenv("VITE_SAMGIS_URL", "/samgis")
24
- VITE_LISA_URL = os.getenv("VITE_LISA_URL", "/lisa")
25
- VITE_GRADIO_URL = os.getenv("VITE_GRADIO_URL", "/gradio")
26
- FASTAPI_TITLE = "samgis-lisa-on-zero"
27
- app = FastAPI(title=FASTAPI_TITLE, version="1.0")
28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
  @spaces.GPU
31
- @session_logger.set_uuid_logging
32
  def gpu_initialization() -> None:
33
  app_logger.info("GPU initialization...")
34
 
35
 
36
  def get_example_complete(example_text):
37
- example_dict = dict(**GRADIO_EXAMPLE_BODY)
38
  example_dict["string_prompt"] = example_text
39
  return json.dumps(example_dict)
40
 
@@ -64,15 +91,14 @@ def get_gradio_interface_geojson(fn_inference: Callable):
64
  return gradio_app
65
 
66
 
67
- @session_logger.set_uuid_logging
68
  def handle_exception_response(exception: Exception) -> NoReturn:
69
  import subprocess
70
  project_root_folder_content = subprocess.run(
71
- f"ls -l {PROJECT_ROOT_FOLDER}/", shell=True, universal_newlines=True, stdout=subprocess.PIPE
72
  )
73
  app_logger.error(f"project_root folder 'ls -l' command output: {project_root_folder_content.stdout}.")
74
  workdir_folder_content = subprocess.run(
75
- f"ls -l {WORKDIR}/", shell=True, universal_newlines=True, stdout=subprocess.PIPE
76
  )
77
  app_logger.error(f"workdir folder 'ls -l' command stdout: {workdir_folder_content.stdout}.")
78
  app_logger.error(f"workdir folder 'ls -l' command stderr: {workdir_folder_content.stderr}.")
@@ -83,32 +109,21 @@ def handle_exception_response(exception: Exception) -> NoReturn:
83
 
84
 
85
  @app.get("/health")
86
- @session_logger.set_uuid_logging
87
- def health() -> JSONResponse:
88
- import importlib.metadata
89
- from importlib.metadata import PackageNotFoundError
90
-
91
- core_version = lisa_on_cuda_version = samgis_lisa_on_cuda_version = ""
92
- try:
93
- core_version = importlib.metadata.version('samgis_core')
94
- lisa_on_cuda_version = importlib.metadata.version('lisa-on-cuda')
95
- samgis_lisa_on_cuda_version = importlib.metadata.version('samgis-lisa-on-zero')
96
- except PackageNotFoundError as pe:
97
- app_logger.error(f"pe:{pe}.")
98
-
99
- msg = "still alive, "
100
- msg += f"""version:{samgis_lisa_on_cuda_version}, core version:{core_version},"""
101
- msg += f"""lisa-on-cuda version:{lisa_on_cuda_version},"""
102
-
103
- app_logger.info(msg)
104
  return JSONResponse(status_code=200, content={"msg": "still alive..."})
105
 
106
 
107
- @session_logger.set_uuid_logging
108
  def infer_lisa_gradio(request_input: StringPromptApiRequestBody) -> str:
109
- from samgis_lisa_on_zero.io_package.wrappers_helpers import get_parsed_bbox_points_with_string_prompt
110
- from samgis_lisa_on_zero.prediction_api import lisa
111
- from samgis_lisa_on_zero.utilities.constants import LISA_INFERENCE_FN
112
 
113
  app_logger.info("starting lisa inference request...")
114
 
@@ -140,13 +155,15 @@ def infer_lisa_gradio(request_input: StringPromptApiRequestBody) -> str:
140
  app_logger.debug(f"complete json.dumps(body):{dumped}.")
141
  return dumped
142
  except Exception as inference_exception:
143
- handle_exception_response(inference_exception)
 
 
144
  except ValidationError as va1:
145
  app_logger.error(f"validation error: {str(va1)}.")
146
- raise ValidationError("Unprocessable Entity")
 
147
 
148
 
149
- @session_logger.set_uuid_logging
150
  @app.post("/infer_lisa")
151
  def infer_lisa(request_input: StringPromptApiRequestBody) -> JSONResponse:
152
  dumped = infer_lisa_gradio(request_input=request_input)
@@ -156,72 +173,46 @@ def infer_lisa(request_input: StringPromptApiRequestBody) -> JSONResponse:
156
 
157
 
158
  @app.exception_handler(RequestValidationError)
159
- @session_logger.set_uuid_logging
160
  def request_validation_exception_handler(request: Request, exc: RequestValidationError) -> JSONResponse:
161
- app_logger.error(f"exception errors: {exc.errors()}.")
162
- app_logger.error(f"exception body: {exc.body}.")
163
- headers = request.headers.items()
164
- app_logger.error(f'request header: {dict(headers)}.')
165
- params = request.query_params.items()
166
- app_logger.error(f'request query params: {dict(params)}.')
167
- return JSONResponse(
168
- status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
169
- content={"msg": "Error - Unprocessable Entity"}
170
- )
171
 
172
 
173
  @app.exception_handler(HTTPException)
174
- @session_logger.set_uuid_logging
175
  def http_exception_handler(request: Request, exc: HTTPException) -> JSONResponse:
176
- app_logger.error(f"exception: {str(exc)}.")
177
- headers = request.headers.items()
178
- app_logger.error(f'request header: {dict(headers)}.')
179
- params = request.query_params.items()
180
- app_logger.error(f'request query params: {dict(params)}.')
181
- return JSONResponse(
182
- status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
183
- content={"msg": "Error - Internal Server Error"}
184
- )
185
 
186
 
 
187
  write_tmp_on_disk = os.getenv("WRITE_TMP_ON_DISK", "")
188
  app_logger.info(f"write_tmp_on_disk:{write_tmp_on_disk}.")
189
  if bool(write_tmp_on_disk):
190
  try:
191
- path_write_tmp_on_disk = pathlib.Path(write_tmp_on_disk)
192
- try:
193
- pathlib.Path.unlink(path_write_tmp_on_disk, missing_ok=True)
194
- except (IsADirectoryError, PermissionError, OSError) as err:
195
- app_logger.error(f"{err} while removing old write_tmp_on_disk:{write_tmp_on_disk}.")
196
- app_logger.error(f"is file?{path_write_tmp_on_disk.is_file()}.")
197
- app_logger.error(f"is symlink?{path_write_tmp_on_disk.is_symlink()}.")
198
- app_logger.error(f"is folder?{path_write_tmp_on_disk.is_dir()}.")
199
- os.makedirs(write_tmp_on_disk, exist_ok=True)
200
  app.mount("/vis_output", StaticFiles(directory=write_tmp_on_disk), name="vis_output")
201
- except RuntimeError as runtime_error:
202
- app_logger.error(f"{runtime_error} while loading the folder write_tmp_on_disk:{write_tmp_on_disk}...")
203
- raise runtime_error
204
- templates = Jinja2Templates(directory=WORKDIR / "static")
205
-
206
 
207
- @app.get("/vis_output", response_class=HTMLResponse)
208
- def list_files(request: Request):
209
 
210
- files = os.listdir(write_tmp_on_disk)
211
- files_paths = sorted([f"{request.url._url}/{f}" for f in files])
212
- print(files_paths)
213
- return templates.TemplateResponse(
214
- "list_files.html", {"request": request, "files": files_paths}
215
- )
 
 
 
216
 
217
- static_dist_folder = WORKDIR / "static" / "dist"
218
  frontend_builder.build_frontend(
219
- project_root_folder=frontend_builder.env_project_root_folder,
220
- input_css_path=frontend_builder.env_input_css_path,
221
  output_dist_folder=static_dist_folder
222
  )
223
- create_folders_and_variables_if_not_exists.folders_creation()
224
-
225
  app_logger.info("build_frontend ok!")
226
 
227
  templates = Jinja2Templates(directory="templates")
@@ -229,44 +220,46 @@ templates = Jinja2Templates(directory="templates")
229
  app.mount("/static", StaticFiles(directory=static_dist_folder, html=True), name="static")
230
  # important: the index() function and the app.mount MUST be at the end
231
  # samgis.html
232
- app.mount(VITE_SAMGIS_URL, StaticFiles(directory=static_dist_folder, html=True), name="samgis")
233
 
234
 
235
- @app.get(VITE_SAMGIS_URL)
236
  async def samgis() -> FileResponse:
237
- return FileResponse(path=static_dist_folder / "samgis.html", media_type="text/html")
238
 
239
 
240
  # lisa.html
241
- app.mount(VITE_LISA_URL, StaticFiles(directory=static_dist_folder, html=True), name="lisa")
242
 
243
 
244
- @app.get(VITE_LISA_URL)
245
  async def lisa() -> FileResponse:
246
- return FileResponse(path=static_dist_folder / "lisa.html", media_type="text/html")
247
 
248
 
249
  # index.html (lisa.html copy)
250
- app.mount(VITE_INDEX_URL, StaticFiles(directory=static_dist_folder, html=True), name="index")
251
 
252
 
253
- @app.get(VITE_INDEX_URL)
254
  async def index() -> FileResponse:
255
- return FileResponse(path=static_dist_folder / "index.html", media_type="text/html")
 
256
 
 
 
 
 
 
257
 
258
- app_helpers.app_logger.info(f"creating gradio interface...")
259
- io = get_gradio_interface_geojson(infer_lisa_gradio)
260
- app_helpers.app_logger.info(
261
- f"gradio interface created, mounting gradio app on url {VITE_GRADIO_URL} within FastAPI...")
262
- app = gr.mount_gradio_app(app, io, path=VITE_GRADIO_URL)
263
- app_helpers.app_logger.info("mounted gradio app within fastapi")
264
 
265
 
266
  if __name__ == '__main__':
267
  try:
268
  uvicorn.run(host="0.0.0.0", port=7860, app=app)
269
  except Exception as ex:
270
- app_logger.error(f"fastapi/gradio application {FASTAPI_TITLE}, exception:{ex}.")
271
- print(f"fastapi/gradio application {FASTAPI_TITLE}, exception:{ex}.")
272
  raise ex
 
1
  import json
2
  import os
3
+ from pathlib import Path
4
  from typing import Callable, NoReturn
5
 
6
+ from asgi_correlation_id import CorrelationIdMiddleware
7
  import gradio as gr
8
  import spaces
9
+ from starlette.responses import JSONResponse
10
+ import structlog
11
  import uvicorn
12
+ from dotenv import load_dotenv
13
  from fastapi import FastAPI, HTTPException, Request, status
14
  from fastapi.exceptions import RequestValidationError
15
+ from fastapi.responses import FileResponse, HTMLResponse
16
  from fastapi.staticfiles import StaticFiles
17
  from fastapi.templating import Jinja2Templates
 
18
  from pydantic import ValidationError
19
+ from samgis_core.utilities import create_folders_if_not_exists
20
+ from samgis_core.utilities import frontend_builder
21
+ from samgis_core.utilities.session_logger import setup_logging
22
+ from samgis_web.utilities.constants import GRADIO_EXAMPLES_TEXT_LIST, GRADIO_MARKDOWN, GRADIO_EXAMPLE_BODY_STRING_PROMPT
23
+ from samgis_web.utilities.type_hints import StringPromptApiRequestBody
24
 
 
 
 
25
 
26
+ load_dotenv()
27
+ project_root_folder = Path(globals().get("__file__", "./_")).absolute().parent
28
+ workdir = Path(os.getenv("WORKDIR", project_root_folder))
29
+ model_folder = Path(project_root_folder / "machine_learning_models")
30
 
31
+ log_level = os.getenv("LOG_LEVEL", "INFO")
32
+ setup_logging(log_level=log_level)
33
+ app_logger = structlog.stdlib.get_logger()
34
+ app_logger.info(f"PROJECT_ROOT_FOLDER:{project_root_folder}, WORKDIR:{workdir}.")
 
 
35
 
36
+ folders_map = os.getenv("FOLDERS_MAP", "{}")
37
+ markdown_text = os.getenv("MARKDOWN_TEXT", "")
38
+ examples_text_list = os.getenv("EXAMPLES_TEXT_LIST", "").split("\n")
39
+ example_body = json.loads(os.getenv("EXAMPLE_BODY", "{}"))
40
+ mount_gradio_app = bool(os.getenv("MOUNT_GRADIO_APP", ""))
41
+
42
+ static_dist_folder = workdir / "static" / "dist"
43
+ input_css_path = os.getenv("INPUT_CSS_PATH", "src/input.css")
44
+ vite_gradio_url = os.getenv("VITE_GRADIO_URL", "/gradio")
45
+ vite_index_url = os.getenv("VITE_INDEX_URL", "/")
46
+ vite_samgis_url = os.getenv("VITE_SAMGIS_URL", "/samgis")
47
+ vite_lisa_url = os.getenv("VITE_LISA_URL", "/lisa")
48
+ fastapi_title = "samgis-lisa-on-zero2"
49
+ app = FastAPI(title=fastapi_title, version="1.0")
50
+
51
+
52
+ @app.middleware("http")
53
+ async def request_middleware(request, call_next):
54
+ from samgis_web.web.middlewares import logging_middleware
55
+
56
+ return await logging_middleware(request, call_next)
57
 
58
  @spaces.GPU
 
59
  def gpu_initialization() -> None:
60
  app_logger.info("GPU initialization...")
61
 
62
 
63
  def get_example_complete(example_text):
64
+ example_dict = dict(**GRADIO_EXAMPLE_BODY_STRING_PROMPT)
65
  example_dict["string_prompt"] = example_text
66
  return json.dumps(example_dict)
67
 
 
91
  return gradio_app
92
 
93
 
 
94
  def handle_exception_response(exception: Exception) -> NoReturn:
95
  import subprocess
96
  project_root_folder_content = subprocess.run(
97
+ f"ls -l {project_root_folder}/", shell=True, universal_newlines=True, stdout=subprocess.PIPE
98
  )
99
  app_logger.error(f"project_root folder 'ls -l' command output: {project_root_folder_content.stdout}.")
100
  workdir_folder_content = subprocess.run(
101
+ f"ls -l {workdir}/", shell=True, universal_newlines=True, stdout=subprocess.PIPE
102
  )
103
  app_logger.error(f"workdir folder 'ls -l' command stdout: {workdir_folder_content.stdout}.")
104
  app_logger.error(f"workdir folder 'ls -l' command stderr: {workdir_folder_content.stderr}.")
 
109
 
110
 
111
  @app.get("/health")
112
+ async def health() -> JSONResponse:
113
+ from samgis_web.__version__ import __version__ as version_web
114
+ from samgis_core.__version__ import __version__ as version_core
115
+ from lisa_on_cuda.__version__ import __version__ as version_lisa_on_cuda
116
+ from samgis_lisa.__version__ import __version__ as version_samgis_lisa
117
+
118
+ app_logger.info(f"still alive, version_web:{version_web}, version_core:{version_core}.")
119
+ app_logger.info(f"still alive, version_lisa_on_cuda:{version_lisa_on_cuda}, version_samgis_lisa:{version_samgis_lisa}.")
 
 
 
 
 
 
 
 
 
 
120
  return JSONResponse(status_code=200, content={"msg": "still alive..."})
121
 
122
 
 
123
  def infer_lisa_gradio(request_input: StringPromptApiRequestBody) -> str:
124
+ from samgis_lisa.io_package.wrappers_helpers import get_parsed_bbox_points_with_string_prompt
125
+ from samgis_lisa.prediction_api import lisa
126
+ from samgis_lisa.utilities.constants import LISA_INFERENCE_FN
127
 
128
  app_logger.info("starting lisa inference request...")
129
 
 
155
  app_logger.debug(f"complete json.dumps(body):{dumped}.")
156
  return dumped
157
  except Exception as inference_exception:
158
+ app_logger.error(f"inference_exception:{inference_exception}.")
159
+ app_logger.error(f"inference_exception, request_input:{request_input}.")
160
+ raise HTTPException(status_code=500, detail="Internal Server Error")
161
  except ValidationError as va1:
162
  app_logger.error(f"validation error: {str(va1)}.")
163
+ app_logger.error(f"ValidationError, request_input:{request_input}.")
164
+ raise RequestValidationError("Unprocessable Entity")
165
 
166
 
 
167
  @app.post("/infer_lisa")
168
  def infer_lisa(request_input: StringPromptApiRequestBody) -> JSONResponse:
169
  dumped = infer_lisa_gradio(request_input=request_input)
 
173
 
174
 
175
  @app.exception_handler(RequestValidationError)
 
176
  def request_validation_exception_handler(request: Request, exc: RequestValidationError) -> JSONResponse:
177
+ from samgis_web.web import exception_handlers
178
+
179
+ return exception_handlers.request_validation_exception_handler(request, exc)
 
 
 
 
 
 
 
180
 
181
 
182
  @app.exception_handler(HTTPException)
 
183
  def http_exception_handler(request: Request, exc: HTTPException) -> JSONResponse:
184
+ from samgis_web.web import exception_handlers
185
+
186
+ return exception_handlers.http_exception_handler(request, exc)
 
 
 
 
 
 
187
 
188
 
189
+ create_folders_if_not_exists.folders_creation(folders_map)
190
  write_tmp_on_disk = os.getenv("WRITE_TMP_ON_DISK", "")
191
  app_logger.info(f"write_tmp_on_disk:{write_tmp_on_disk}.")
192
  if bool(write_tmp_on_disk):
193
  try:
194
+ assert Path(write_tmp_on_disk).is_dir()
 
 
 
 
 
 
 
 
195
  app.mount("/vis_output", StaticFiles(directory=write_tmp_on_disk), name="vis_output")
196
+ templates = Jinja2Templates(directory=str(project_root_folder / "static"))
 
 
 
 
197
 
198
+ @app.get("/vis_output", response_class=HTMLResponse)
199
+ def list_files(request: Request):
200
 
201
+ files = os.listdir(write_tmp_on_disk)
202
+ files_paths = sorted([f"{request.url._url}/{f}" for f in files])
203
+ print(files_paths)
204
+ return templates.TemplateResponse(
205
+ "list_files.html", {"request": request, "files": files_paths}
206
+ )
207
+ except (AssertionError, RuntimeError) as rerr:
208
+ app_logger.error(f"{rerr} while loading the folder write_tmp_on_disk:{write_tmp_on_disk}...")
209
+ raise rerr
210
 
 
211
  frontend_builder.build_frontend(
212
+ project_root_folder=workdir,
213
+ input_css_path=input_css_path,
214
  output_dist_folder=static_dist_folder
215
  )
 
 
216
  app_logger.info("build_frontend ok!")
217
 
218
  templates = Jinja2Templates(directory="templates")
 
220
  app.mount("/static", StaticFiles(directory=static_dist_folder, html=True), name="static")
221
  # important: the index() function and the app.mount MUST be at the end
222
  # samgis.html
223
+ app.mount(vite_samgis_url, StaticFiles(directory=static_dist_folder, html=True), name="samgis")
224
 
225
 
226
+ @app.get(vite_samgis_url)
227
  async def samgis() -> FileResponse:
228
+ return FileResponse(path=str(static_dist_folder / "samgis.html"), media_type="text/html")
229
 
230
 
231
  # lisa.html
232
+ app.mount(vite_lisa_url, StaticFiles(directory=static_dist_folder, html=True), name="lisa")
233
 
234
 
235
+ @app.get(vite_lisa_url)
236
  async def lisa() -> FileResponse:
237
+ return FileResponse(path=str(static_dist_folder / "lisa.html"), media_type="text/html")
238
 
239
 
240
  # index.html (lisa.html copy)
241
+ app.mount(vite_index_url, StaticFiles(directory=static_dist_folder, html=True), name="index")
242
 
243
 
244
+ @app.get(vite_index_url)
245
  async def index() -> FileResponse:
246
+ return FileResponse(path=str(static_dist_folder / "index.html"), media_type="text/html")
247
+
248
 
249
+ app_logger.info(f"creating gradio interface...")
250
+ gr_interface = get_gradio_interface_geojson(infer_lisa_gradio)
251
+ app_logger.info(f"gradio interface created, mounting gradio app on url {vite_gradio_url} within FastAPI...")
252
+ app = gr.mount_gradio_app(app, gr_interface, path=vite_gradio_url)
253
+ app_logger.info("mounted gradio app within fastapi")
254
 
255
+ # add the CorrelationIdMiddleware AFTER the @app.middleware("http") decorated function to avoid missing request id
256
+ app.add_middleware(CorrelationIdMiddleware)
 
 
 
 
257
 
258
 
259
  if __name__ == '__main__':
260
  try:
261
  uvicorn.run(host="0.0.0.0", port=7860, app=app)
262
  except Exception as ex:
263
+ app_logger.error(f"fastapi/gradio application {fastapi_title}, exception:{ex}.")
264
+ print(f"fastapi/gradio application {fastapi_title}, exception:{ex}.")
265
  raise ex