# node-py-test / app.py
# Source: Hugging Face Space file view. Last commit "Update app.py"
# (ec0346c, verified) by pngwn (HF staff); file size 4.07 kB.
import httpx
from fastapi import FastAPI, Request
from fastapi.responses import Response, StreamingResponse, JSONResponse
from fastapi.middleware.cors import CORSMiddleware
import uvicorn
import subprocess
import signal
import time
import os
# Port the FastAPI/uvicorn front server binds to (used by the __main__ guard).
PYTHON_PORT = 7860
# Port the Node.js server is expected to listen on.
# NOTE(review): this value is not passed to the Node process below; presumably
# the Node build binds 4321 on its own. Confirm.
NODE_PORT = 4321
# Entry point handed to the `node` executable.
NODE_SCRIPT_PATH = "build"

app = FastAPI()

# Fully permissive CORS.
# NOTE(review): wildcard origins combined with allow_credentials=True may let
# any origin make credentialed requests. Tighten before serving anything
# sensitive.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)

# Spawned at import time, so Node starts as soon as this module is loaded.
node_process = subprocess.Popen(args=["node", NODE_SCRIPT_PATH])
def handle_sigterm(signum, frame):
    """SIGTERM handler: stop the Node.js child process, then exit.

    Node is asked to terminate and given a 10 second grace period. If it is
    still running after that, it is killed, so this handler cannot block
    forever.

    Args:
        signum: Number of the delivered signal (unused).
        frame: Stack frame interrupted by the signal (unused).

    Raises:
        SystemExit: Always, with exit status 0.
    """
    print("Stopping Node.js server...")
    node_process.terminate()
    try:
        node_process.wait(timeout=10)
    except subprocess.TimeoutExpired:
        # Node ignored the terminate request; force it down.
        node_process.kill()
        node_process.wait()
    # The builtin ``exit`` is a REPL helper injected by the ``site`` module and
    # may be missing (e.g. ``python -S``). Raise SystemExit directly.
    raise SystemExit(0)
# Single shared async HTTP client; proxy_to_node sends every forwarded request
# through it.
client = httpx.AsyncClient()

# On SIGTERM, stop the Node child and exit (see handle_sigterm).
# NOTE(review): uvicorn.run installs its own signal handlers, which may replace
# this one while the server is running. Confirm which handler actually fires.
signal.signal(signal.SIGTERM, handle_sigterm)
@app.on_event("shutdown")
def shutdown_event():
    """Stop the Node.js child process when the FastAPI app shuts down.

    Node is asked to terminate and given a 10 second grace period. If it has
    not exited by then it is killed, so application shutdown cannot hang
    indefinitely.
    """
    print("Stopping Node.js server...")
    node_process.terminate()
    try:
        node_process.wait(timeout=10)
    except subprocess.TimeoutExpired:
        # Node ignored the terminate request; force it down rather than
        # blocking shutdown forever.
        node_process.kill()
        node_process.wait()
@app.get("/config")
async def route_with_config():
    """Serve a fixed JSON greeting straight from Python (never proxied to Node)."""
    greeting = dict(one="hello", two="from", three="Python")
    return JSONResponse(greeting)
# async def proxy_to_node(request: Request):
# # Preserve the full path including query parameters
# full_path = request.url.path
# if request.url.query:
# full_path += f"?{request.url.query}"
# url = f"http://localhost:{NODE_PORT}{full_path}"
# headers = {
# k: v
# for k, v in request.headers.items()
# if k.lower() not in ["host", "content-length"]
# }
# print(headers)
# # body = await request.body()
# # async with client:
# # response = await client.request(
# # method=request.method, url=url, headers=headers, content=body
# # )
# # return StreamingResponse(
# # response.iter_bytes(),
# # status_code=response.status_code,
# # headers=response.headers,
# # )
# req = client.build_request("GET", httpx.URL(url), headers=headers)
# r = await client.send(req, stream=True)
# return StreamingResponse(
# r.aiter_raw(), headers=r.headers
# )
# Headers that must not be relayed in either direction. The hop-by-hop set
# (RFC 9110 §7.6.1) describes a single connection only. Content-Length is
# listed too, so it is recomputed for the body actually sent on each leg.
_UNFORWARDED_HEADERS = frozenset(
    {
        "connection",
        "keep-alive",
        "proxy-authenticate",
        "proxy-authorization",
        "proxy-connection",
        "te",
        "trailer",
        "transfer-encoding",
        "upgrade",
        "content-length",
    }
)


async def proxy_to_node(
    request: Request,
    server_name: str,
    node_port: int,
    python_port: int,
    scheme: str = "http",
    mounted_path: str = "",
):
    """Forward ``request`` to the Node.js server and relay its response.

    The method, path, query string, headers and body of ``request`` are sent
    to ``{scheme}://{server_name}:{node_port}``. Extra headers are added:

    * ``x-gradio-server``: the base URL of this Python server.
    * ``x-gradio-port``: the Python server's port.
    * ``x-gradio-local-dev-mode: 1``, only when the ``GRADIO_LOCAL_DEV_MODE``
      environment variable is set.

    Args:
        request: Incoming FastAPI/Starlette request.
        server_name: Host used both to reach Node and to build
            ``x-gradio-server``.
        node_port: Port of the Node.js server.
        python_port: Port of this Python server. Falsy values (e.g. ``None``)
            are omitted from ``x-gradio-server``.
        scheme: URL scheme for the upstream request and ``x-gradio-server``.
        mounted_path: Leading path prefix stripped before forwarding. It is
            appended to ``x-gradio-server``.

    Returns:
        A ``Response`` carrying Node's status code, its headers (minus
        hop-by-hop/framing ones) and its body bytes exactly as received.
        Returns a 502 response if Node could not be reached.
    """
    start_time = time.time()

    full_path = request.url.path
    # Strip the mount point only as a leading prefix. str.replace would also
    # delete any later occurrence inside the path.
    if mounted_path and full_path.startswith(mounted_path):
        full_path = full_path[len(mounted_path):]
    if request.url.query:
        full_path += f"?{request.url.query}"
    url = f"{scheme}://{server_name}:{node_port}{full_path}"

    # Host is left in place, so Node receives the client's original Host value.
    # Framing headers are dropped because httpx re-derives them from the body
    # passed below.
    headers = {
        key: value
        for key, value in request.headers.items()
        if key.lower() not in _UNFORWARDED_HEADERS
    }
    # NOTE(review): logs every request header, including cookies and
    # authorization. Intended for debugging only.
    print(
        headers,
    )

    server_url = f"{scheme}://{server_name}"
    if python_port:
        server_url += f":{python_port}"
    if mounted_path:
        server_url += mounted_path
    headers["x-gradio-server"] = server_url
    # NOTE(review): this becomes "None" when python_port is None. Confirm
    # what the Node side expects.
    headers["x-gradio-port"] = str(python_port)
    print(
        f"Proxying request from {request.url.path} to {url} with server url {server_url}"
    )
    if os.getenv("GRADIO_LOCAL_DEV_MODE"):
        headers["x-gradio-local-dev-mode"] = "1"

    # Forward the body too. Without it, POST/PUT/PATCH arrive at Node empty.
    body = await request.body()
    print(f"Time to prepare request: {time.time() - start_time:.4f} seconds")
    print(
        f"Total setup time before streaming: {time.time() - start_time:.4f} seconds"
    )

    req = client.build_request(
        request.method, httpx.URL(url), headers=headers, content=body
    )
    try:
        upstream = await client.send(req, stream=True)
        try:
            # aiter_raw yields the bytes as received, with no gzip/br decoding,
            # so they stay consistent with the relayed Content-Encoding header.
            raw_body = b"".join([chunk async for chunk in upstream.aiter_raw()])
        finally:
            await upstream.aclose()
    except httpx.RequestError as exc:
        # Node is down or the connection failed. Answer like a gateway instead
        # of surfacing an unhandled 500.
        print(f"Proxy request to {url} failed: {exc!r}")
        return Response(
            content="Bad Gateway", status_code=502, media_type="text/plain"
        )

    print(f"Time to receive response: {time.time() - start_time:.4f} seconds")
    print(f"\nHeaders: {upstream.headers}\n")

    # Starlette computes Content-Length from raw_body. Copy the remaining
    # headers one at a time so repeated ones (notably Set-Cookie) survive.
    # A plain dict would collapse them into a single comma-joined value.
    response = Response(content=raw_body, status_code=upstream.status_code)
    for key, value in upstream.headers.multi_items():
        if key.lower() not in _UNFORWARDED_HEADERS:
            response.headers.append(key, value)
    return response
@app.api_route(
    "/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"]
)
async def catch_all(request: Request, path: str):
    """Proxy any request not matched by an earlier route to the Node.js server.

    This route is registered after the Python-served routes (such as
    ``/config``), so it only sees the remaining paths. ``path`` is required by
    the route pattern but unused: the full path and query are read from
    ``request.url`` instead.
    """
    return await proxy_to_node(
        request,
        # NOTE(review): connecting to 0.0.0.0 typically reaches the local host
        # on Linux/macOS but is not portable. The value is also embedded in the
        # x-gradio-server header sent to Node.
        server_name="0.0.0.0",
        node_port=NODE_PORT,
        python_port=request.url.port,
        scheme=request.url.scheme,
        mounted_path="",
    )
if __name__ == "__main__":
    # Explicitly registered routes (e.g. /config) are answered by Python.
    # Every other path falls through to the catch-all route and is proxied
    # to Node.
    print(
        "Starting dual server. Python handles specific routes, Node handles the rest."
    )
    uvicorn.run(app, host="0.0.0.0", port=PYTHON_PORT)