"""Dual-server launcher.

FastAPI answers a small set of routes itself (currently ``/config``) and
proxies every other request to a Node.js server that this module starts as a
child process at import time.
"""

import os
import signal
import subprocess
import sys
import time
from typing import Optional

import httpx
import uvicorn
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, Response, StreamingResponse

app = FastAPI()

# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive -- confirm this is intended outside local development.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

PYTHON_PORT = 7860
NODE_PORT = 4321
NODE_SCRIPT_PATH = "build"

# Seconds to wait for Node to exit after SIGTERM before it is killed.
NODE_STOP_TIMEOUT = 10

# Request headers that must not be copied to the upstream request: httpx
# re-frames the body itself, and it should advertise only the encodings it is
# able to decode (the decoded body is what gets relayed to the client).
_DROPPED_REQUEST_HEADERS = ("content-length", "transfer-encoding", "accept-encoding")

# Response headers that describe the upstream wire framing. httpx hands back
# an already-decoded body, so relaying these would mislead the client.
_DROPPED_RESPONSE_HEADERS = frozenset(
    {"content-encoding", "content-length", "transfer-encoding", "connection"}
)

# Started at import time so the Node server is up alongside the Python one.
node_process = subprocess.Popen(["node", NODE_SCRIPT_PATH])


def _stop_node_server() -> None:
    """Terminate the Node.js child process, killing it if it does not exit."""
    print("Stopping Node.js server...")
    if node_process.poll() is not None:
        return  # already exited; nothing to stop
    node_process.terminate()
    try:
        node_process.wait(timeout=NODE_STOP_TIMEOUT)
    except subprocess.TimeoutExpired:
        node_process.kill()
        node_process.wait()


def handle_sigterm(signum, frame):
    """SIGTERM handler: stop the Node.js child, then exit with status 0."""
    _stop_node_server()
    sys.exit(0)


signal.signal(signal.SIGTERM, handle_sigterm)

# Shared client so upstream connections are pooled across requests.
client = httpx.AsyncClient()


@app.on_event("shutdown")
def shutdown_event():
    """Stop the Node.js child process when the FastAPI app shuts down."""
    _stop_node_server()


@app.on_event("shutdown")
async def close_http_client():
    """Release the shared httpx client's connection pool on shutdown."""
    await client.aclose()


@app.get("/config")
async def route_with_config():
    """Return a static JSON payload served directly by Python."""
    return JSONResponse(content={"one": "hello", "two": "from", "three": "Python"})


def _strip_mount(path: str, mounted_path: str) -> str:
    """Remove a leading ``mounted_path`` prefix from ``path`` (prefix only)."""
    if not mounted_path:
        return path
    prefix = mounted_path.rstrip("/")
    if path == prefix or path.startswith(prefix + "/"):
        return path[len(prefix):] or "/"
    return path


async def proxy_to_node(
    request: Request,
    server_name: str,
    node_port: int,
    python_port: Optional[int],
    scheme: str = "http",
    mounted_path: str = "",
):
    """Forward ``request`` to the Node.js server and relay its response.

    Args:
        request: Incoming request; its method, path, query string, headers
            and body are forwarded.
        server_name: Host used both to reach Node and to build the
            ``x-gradio-server`` header value.
        node_port: Port the Node.js server listens on.
        python_port: Port of this Python server, or ``None`` when the request
            URL carries no explicit port. When falsy it is left out of the
            server URL and ``x-gradio-port`` is not sent.
        scheme: URL scheme used for the upstream URL and the server URL.
        mounted_path: Path prefix removed from the request path before
            proxying and appended to the advertised server URL.

    Returns:
        A response with the upstream status code, headers (minus
        wire-framing headers) and decoded body, or a 502 JSON response when
        the upstream request fails.
    """
    start_time = time.time()

    full_path = _strip_mount(request.url.path, mounted_path)
    if request.url.query:
        full_path += f"?{request.url.query}"
    url = f"{scheme}://{server_name}:{node_port}{full_path}"

    headers = dict(request.headers)
    for name in _DROPPED_REQUEST_HEADERS:
        headers.pop(name, None)

    server_url = f"{scheme}://{server_name}"
    if python_port:
        server_url += f":{python_port}"
        headers["x-gradio-port"] = str(python_port)
    if mounted_path:
        server_url += mounted_path
    headers["x-gradio-server"] = server_url

    # Header values are deliberately not logged: they can carry cookies and
    # authorization tokens.
    print(
        f"Proxying request from {request.url.path} to {url} with server url {server_url}"
    )

    if os.getenv("GRADIO_LOCAL_DEV_MODE"):
        headers["x-gradio-local-dev-mode"] = "1"

    body = await request.body()
    print(f"Time to prepare request: {time.time() - start_time:.4f} seconds")

    req = client.build_request(
        request.method, httpx.URL(url), headers=headers, content=body or None
    )
    try:
        r = await client.send(req)
    except httpx.RequestError as exc:
        # Node may still be starting, or may have crashed; report a gateway
        # error instead of an unhandled 500.
        print(f"Upstream request to {url} failed: {exc!r}")
        return JSONResponse(
            status_code=502,
            content={"error": "Upstream Node.js server is unreachable"},
        )
    print(f"Time to receive response: {time.time() - start_time:.4f} seconds")

    response = Response(content=r.content, status_code=r.status_code)
    # multi_items() keeps repeated headers (e.g. several Set-Cookie) separate.
    for key, value in r.headers.multi_items():
        if key.lower() not in _DROPPED_RESPONSE_HEADERS:
            response.headers.append(key, value)
    return response


@app.api_route(
    "/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"]
)
async def catch_all(request: Request, path: str):
    """Proxy every route not handled above to the Node.js server."""
    return await proxy_to_node(
        request,
        "0.0.0.0",
        NODE_PORT,
        request.url.port,
        request.url.scheme,
        "",
    )


if __name__ == "__main__":
    print(
        "Starting dual server. Python handles specific routes, Node handles the rest."
    )
    uvicorn.run(app, host="0.0.0.0", port=PYTHON_PORT)