File size: 4,202 Bytes
d2aab3f
 
fb032b1
d2aab3f
 
 
 
517c9f7
821b8d3
d2aab3f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63c8cd1
d2aab3f
 
 
 
 
 
 
 
 
 
 
 
 
b5e785c
63c8cd1
d2aab3f
b5e785c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
de827dc
b5e785c
 
d2aab3f
b5e785c
 
d2aab3f
 
 
b5e785c
d2aab3f
d0d3b45
 
a4618d8
b5e785c
 
 
 
 
 
 
 
 
 
 
 
d2aab3f
b5e785c
 
 
63c8cd1
b5e785c
 
63c8cd1
b5e785c
d2aab3f
b5e785c
 
d2aab3f
 
1161970
fb032b1
b5e785c
 
d0d3b45
 
fb032b1
 
 
 
 
8ee459f
0b5886f
029b549
8ee459f
029b549
d2aab3f
63c8cd1
d2aab3f
 
 
 
1161970
b5e785c
562d68b
b5e785c
 
 
 
 
d2aab3f
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
import httpx
from fastapi import FastAPI, Request
from fastapi.responses import Response, StreamingResponse, JSONResponse
from fastapi.middleware.cors import CORSMiddleware
import uvicorn
import subprocess
import signal
import time
import os

app = FastAPI()

# CORS policy: any origin, any method, any header, with credentials allowed.
# NOTE(review): a wildcard origin combined with allow_credentials=True may let
# arbitrary sites make credentialed cross-origin requests — confirm intended.
_cors_settings = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_settings)


# Port uvicorn serves this FastAPI app on (used by the __main__ entry point).
PYTHON_PORT = 7860
# Port the catch-all proxy targets. Assumed to match the port the Node server
# listens on; nothing here passes it to the child — TODO confirm.
NODE_PORT = 4321
# Argument handed to `node`; presumably the built JS server entry — verify.
NODE_SCRIPT_PATH = "build"

# The Node.js child is launched as soon as this module is imported, so it runs
# alongside the Python server for the lifetime of the process.
node_process = subprocess.Popen(
    args=["node", NODE_SCRIPT_PATH],
)


def handle_sigterm(signum, frame):
    """Stop the Node.js child process and exit when SIGTERM is received.

    Args:
        signum: The delivered signal number (unused).
        frame: The interrupted stack frame (unused).

    Raises:
        SystemExit: Always, with status 0, once the child has stopped.
    """
    print("Stopping Node.js server...")
    node_process.terminate()
    try:
        # Give Node a grace period to exit cleanly; a child that ignores
        # SIGTERM must not keep this process alive forever.
        node_process.wait(timeout=10)
    except subprocess.TimeoutExpired:
        node_process.kill()
        node_process.wait()
    # `exit()` is a `site`-module helper intended for the interactive shell;
    # raising SystemExit is what sys.exit(0) does, without needing an import.
    # NOTE(review): uvicorn.run may install its own SIGTERM handler over this
    # one — confirm this handler is actually reached when served via uvicorn.
    raise SystemExit(0)


# Route SIGTERM to the handler above so the Node child is stopped with us.
signal.signal(signal.Signals.SIGTERM, handle_sigterm)

# One shared async HTTP client, reused by every proxied request.
client = httpx.AsyncClient()

@app.on_event("shutdown")
def shutdown_event():
    """Stop the Node.js child process when the FastAPI app shuts down.

    Sends a terminate request first and escalates to kill() if the child has
    not exited within the grace period, so a hung Node process cannot block
    application shutdown indefinitely.
    """
    print("Stopping Node.js server...")
    node_process.terminate()
    try:
        node_process.wait(timeout=10)
    except subprocess.TimeoutExpired:
        node_process.kill()
        node_process.wait()


@app.get("/config")
async def route_with_config():
    """Serve a fixed JSON payload directly from Python (this route is not proxied)."""
    payload = dict(one="hello", two="from", three="Python")
    return JSONResponse(payload)


# async def proxy_to_node(request: Request):
    

#     # Preserve the full path including query parameters
#     full_path = request.url.path
#     if request.url.query:
#         full_path += f"?{request.url.query}"

#     url = f"http://localhost:{NODE_PORT}{full_path}"

#     headers = {
#         k: v
#         for k, v in request.headers.items()
#         if k.lower() not in ["host", "content-length"]
#     }

#     print(headers)
#     # body = await request.body()

#     # async with client:
#     #     response = await client.request(
#     #         method=request.method, url=url, headers=headers, content=body
#     #     )

#     # return StreamingResponse(
#     #     response.iter_bytes(),
#     #     status_code=response.status_code,
#     #     headers=response.headers,
#     # )

#     req = client.build_request("GET", httpx.URL(url), headers=headers)
#     r = await client.send(req, stream=True)

#     return StreamingResponse(
#         r.aiter_raw(), headers=r.headers
#     )

# Request headers not copied to the upstream request: httpx computes its own
# body framing for the `content` it is given, and connection-management
# headers only apply to the client <-> proxy hop.
_SKIPPED_REQUEST_HEADERS = (
    "content-length",
    "transfer-encoding",
    "connection",
    "keep-alive",
)

# Upstream response headers dropped before replying: httpx transparently
# decodes compressed bodies, so the upstream Content-Encoding/Content-Length
# no longer describe `upstream.content` (Starlette recomputes the length), and
# connection-management headers are hop-by-hop.
_SKIPPED_RESPONSE_HEADERS = frozenset(
    {
        b"content-encoding",
        b"content-length",
        b"transfer-encoding",
        b"connection",
        b"keep-alive",
    }
)


def _strip_mount_prefix(path: str, mounted_path: str) -> str:
    """Return `path` with a leading `mounted_path` segment removed.

    Only a whole leading segment is stripped: "/app" is removed from "/app/x"
    but not from "/apple" or "/x/app". An exact match becomes "/".
    """
    prefix = mounted_path.rstrip("/")
    if not prefix:
        return path
    if path == prefix:
        return "/"
    if path.startswith(prefix + "/"):
        return path[len(prefix):]
    return path


async def proxy_to_node(
        request: Request,
        server_name: str,
        node_port: int,
        python_port: int,
        scheme: str = "http",
        mounted_path: str = "",
):
    """Forward `request` to the Node.js server and relay its response.

    The upstream request keeps the incoming method, body, query string and
    (non hop-by-hop) headers. It adds `x-gradio-server` / `x-gradio-port`
    describing the Python server's address, plus `x-gradio-local-dev-mode`
    when the GRADIO_LOCAL_DEV_MODE environment variable is set.

    Args:
        request: The incoming request to forward.
        server_name: Host of the Node server; also used to build the
            `x-gradio-server` value.
        node_port: Port of the Node server.
        python_port: Port of this Python server; may be None when the client
            used the scheme's default port (it then comes from request.url.port).
        scheme: URL scheme for both the upstream URL and `x-gradio-server`.
        mounted_path: Prefix this app is mounted under; stripped from the
            forwarded path and appended to `x-gradio-server`.

    Returns:
        A Response carrying the upstream status code, body and headers, or a
        502 response if the Node server could not be reached.
    """
    start_time = time.time()

    full_path = _strip_mount_prefix(request.url.path, mounted_path)
    if request.url.query:
        full_path += f"?{request.url.query}"
    url = f"{scheme}://{server_name}:{node_port}{full_path}"

    # Starlette header names are already lower-case.
    headers = dict(request.headers)
    for name in _SKIPPED_REQUEST_HEADERS:
        headers.pop(name, None)

    server_url = f"{scheme}://{server_name}"
    if python_port:
        server_url += f":{python_port}"
    if mounted_path:
        server_url += mounted_path

    headers["x-gradio-server"] = server_url
    # NOTE(review): when python_port is None this sends the literal string
    # "None" — confirm the Node side handles that.
    headers["x-gradio-port"] = str(python_port)
    if os.getenv("GRADIO_LOCAL_DEV_MODE"):
        headers["x-gradio-local-dev-mode"] = "1"

    print(
        f"Proxying request from {request.url.path} to {url} with server url {server_url}"
    )

    body = await request.body()
    upstream_request = client.build_request(
        request.method, httpx.URL(url), headers=headers, content=body
    )
    try:
        upstream = await client.send(upstream_request)
    except httpx.RequestError as exc:
        print(f"Failed to reach Node.js server at {url}: {exc!r}")
        return Response(
            content=b"Bad Gateway", status_code=502, media_type="text/plain"
        )

    print(
        f"Proxied {request.method} {request.url.path} -> {upstream.status_code} "
        f"in {time.time() - start_time:.4f} seconds"
    )

    response = Response(content=upstream.content, status_code=upstream.status_code)
    # Copy headers as raw byte pairs so repeated ones (e.g. several
    # Set-Cookie) survive instead of being merged into one comma-joined value.
    response.raw_headers.extend(
        (name.lower(), value)
        for name, value in upstream.headers.raw
        if name.lower() not in _SKIPPED_RESPONSE_HEADERS
    )
    return response


@app.api_route(
    "/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"]
)
async def catch_all(request: Request, path: str):
    """Forward every request not matched by a Python route to the Node server.

    `path` is captured by the route pattern but not used directly:
    proxy_to_node rebuilds the upstream path from `request.url`.
    """
    return await proxy_to_node(
        request,
        server_name="0.0.0.0",
        # Use the module constant rather than a duplicated literal so the
        # proxy target stays in sync with NODE_PORT.
        node_port=NODE_PORT,
        python_port=request.url.port,
        scheme=request.url.scheme,
        mounted_path="",
    )


if __name__ == "__main__":
    # Plain string: the original f-string had no placeholders (ruff F541).
    print(
        "Starting dual server. Python handles specific routes, Node handles the rest."
    )
    uvicorn.run(app, host="0.0.0.0", port=PYTHON_PORT)