# node-py-test / app.py
# Author: pngwn (HF staff)
# Last commit: 8ee459f "Update app.py" (verified)
# Size: 4.2 kB
import os
import signal
import subprocess
import sys
import time

import httpx
import uvicorn
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, JSONResponse
# Port layout: this FastAPI app listens on PYTHON_PORT and forwards every route
# it does not serve itself to the Node.js server on NODE_PORT.
PYTHON_PORT = 7860
NODE_PORT = 4321
# Argument passed to the ``node`` executable to start the Node.js server.
NODE_SCRIPT_PATH = "build"

app = FastAPI()

# NOTE(review): wildcard origins combined with allow_credentials=True —
# confirm that credentialed cross-origin access from any site is intended.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# The Node.js server is started as a child process as soon as this module is
# imported. NOTE(review): NODE_PORT is not passed to the child; presumably the
# Node server is configured elsewhere to listen on that port — confirm.
node_process = subprocess.Popen(args=["node", NODE_SCRIPT_PATH])
def handle_sigterm(signum, frame):
    """SIGTERM handler: stop the Node.js child process, then exit this process.

    Args:
        signum: Number of the delivered signal (unused).
        frame: Stack frame interrupted by the signal (unused).
    """
    print("Stopping Node.js server...")
    node_process.terminate()
    try:
        # Bound the wait so a Node process that ignores SIGTERM cannot block
        # shutdown forever; escalate to SIGKILL if it is still running.
        node_process.wait(timeout=10)
    except subprocess.TimeoutExpired:
        node_process.kill()
        node_process.wait()
    # ``exit`` is injected by the ``site`` module for interactive use and may
    # be absent (e.g. under ``python -S``); ``sys.exit`` is always available.
    sys.exit(0)
# Route SIGTERM through handle_sigterm so the Node.js child is stopped too.
signal.signal(signal.SIGTERM, handle_sigterm)

# One module-level async HTTP client shared by every proxied request.
client: httpx.AsyncClient = httpx.AsyncClient()
@app.on_event("shutdown")
def shutdown_event():
    """Stop the Node.js child process when the FastAPI app shuts down."""
    print("Stopping Node.js server...")
    node_process.terminate()
    try:
        # Bound the wait so a Node process that ignores SIGTERM cannot hang
        # application shutdown; fall back to SIGKILL after the timeout.
        node_process.wait(timeout=10)
    except subprocess.TimeoutExpired:
        node_process.kill()
        node_process.wait()
@app.get("/config")
async def route_with_config():
    """Serve a fixed JSON payload directly from Python (not proxied to Node)."""
    payload = dict(one="hello", two="from", three="Python")
    return JSONResponse(content=payload)
# async def proxy_to_node(request: Request):
# # Preserve the full path including query parameters
# full_path = request.url.path
# if request.url.query:
# full_path += f"?{request.url.query}"
# url = f"http://localhost:{NODE_PORT}{full_path}"
# headers = {
# k: v
# for k, v in request.headers.items()
# if k.lower() not in ["host", "content-length"]
# }
# print(headers)
# # body = await request.body()
# # async with client:
# # response = await client.request(
# # method=request.method, url=url, headers=headers, content=body
# # )
# # return StreamingResponse(
# # response.iter_bytes(),
# # status_code=response.status_code,
# # headers=response.headers,
# # )
# req = client.build_request("GET", httpx.URL(url), headers=headers)
# r = await client.send(req, stream=True)
# return StreamingResponse(
# r.aiter_raw(), headers=r.headers
# )
async def proxy_to_node(
    request: Request,
    server_name: str,
    node_port: int,
    python_port: int,
    scheme: str = "http",
    mounted_path: str = "",
):
    """Forward ``request`` to the Node.js server and stream its reply back.

    The incoming method, path, query string, headers and body are replayed
    against ``{scheme}://{server_name}:{node_port}``. The upstream status code,
    headers and raw (still content-encoded) body are relayed to the client.
    Two extra headers, ``x-gradio-server`` and ``x-gradio-port``, tell the
    Node side where this Python server lives.

    Args:
        request: The incoming request to forward.
        server_name: Host used to reach the Node server and to build the
            ``x-gradio-server`` URL.
        node_port: Port the Node server listens on.
        python_port: Port of this Python server. Appended to
            ``x-gradio-server`` when truthy; always sent as ``x-gradio-port``.
        scheme: URL scheme for both the upstream URL and ``x-gradio-server``.
        mounted_path: Prefix this app is mounted under. It is removed from the
            start of the request path before forwarding and appended to
            ``x-gradio-server``.

    Returns:
        A ``StreamingResponse`` relaying the upstream response.

    Raises:
        httpx transport errors (e.g. ``httpx.ConnectError``) propagate if the
        Node server cannot be reached.
    """
    start_time = time.time()

    full_path = request.url.path
    # Strip the mount prefix only when it is the leading path segment;
    # ``str.replace`` would also remove matches later in the path.
    prefix = mounted_path.rstrip("/")
    if prefix and (full_path == prefix or full_path.startswith(prefix + "/")):
        full_path = full_path[len(prefix):] or "/"
    if request.url.query:
        full_path += f"?{request.url.query}"
    url = f"{scheme}://{server_name}:{node_port}{full_path}"

    # Forward the client's headers minus body-framing ones: httpx derives
    # Content-Length / Transfer-Encoding from the body it actually sends.
    headers = dict(request.headers)
    for framing_header in ("content-length", "transfer-encoding"):
        headers.pop(framing_header, None)
    print(headers)

    server_url = f"{scheme}://{server_name}"
    if python_port:
        server_url += f":{python_port}"
    if mounted_path:
        server_url += mounted_path
    headers["x-gradio-server"] = server_url
    headers["x-gradio-port"] = str(python_port)
    print(
        f"Proxying request from {request.url.path} to {url} with server url {server_url}"
    )

    if os.getenv("GRADIO_LOCAL_DEV_MODE"):
        headers["x-gradio-local-dev-mode"] = "1"

    print(f"Time to prepare request: {time.time() - start_time:.4f} seconds")
    print(
        f"Total setup time before streaming: {time.time() - start_time:.4f} seconds"
    )

    # Replay the original method and body; previously every request was sent
    # upstream as a body-less GET regardless of the route's allowed methods.
    body = await request.body()
    upstream_request = client.build_request(
        request.method, httpx.URL(url), headers=headers, content=body
    )
    upstream = await client.send(upstream_request, stream=True)
    print(
        f"Time to receive upstream response headers: {time.time() - start_time:.4f} seconds"
    )
    print(f"\nHeaders: {upstream.headers}\n")

    async def relay_body():
        # Close the upstream response when streaming finishes or is abandoned,
        # so the shared client can release the underlying connection.
        try:
            async for chunk in upstream.aiter_raw():
                yield chunk
        finally:
            await upstream.aclose()

    response = StreamingResponse(relay_body(), status_code=upstream.status_code)
    # Copy headers one by one: multi_items() keeps repeated headers such as
    # Set-Cookie separate, whereas items() comma-joins them. Framing headers
    # are dropped because the body is re-framed when streamed to the client.
    skipped = {"content-length", "transfer-encoding", "connection", "keep-alive"}
    for key, value in upstream.headers.multi_items():
        if key.lower() not in skipped:
            response.headers.append(key, value)
    return response
@app.api_route(
    "/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"]
)
async def catch_all(request: Request, path: str):
    """Proxy every route not matched earlier to the Node.js server.

    ``path`` is required by the route pattern. The full path is taken from
    ``request`` inside ``proxy_to_node``.
    """
    return await proxy_to_node(
        request,
        # NOTE(review): "0.0.0.0" is used both as the upstream host and in the
        # advertised x-gradio-server URL — confirm this is intended.
        server_name="0.0.0.0",
        # Use the module constant rather than repeating the literal port.
        node_port=NODE_PORT,
        # NOTE(review): request.url.port is None when the client used the
        # scheme's default port, in which case x-gradio-port becomes "None".
        python_port=request.url.port,
        scheme=request.url.scheme,
        mounted_path="",
    )
if __name__ == "__main__":
    # Python serves its own routes (e.g. /config) and proxies the rest to the
    # Node.js child process, which was started when this module was imported.
    # Plain string: the message has no placeholders, so no f-string is needed.
    print(
        "Starting dual server. Python handles specific routes, Node handles the rest."
    )
    uvicorn.run(app, host="0.0.0.0", port=PYTHON_PORT)