|
import uvicorn
|
|
from fastapi import FastAPI, Depends
|
|
from starlette.responses import RedirectResponse
|
|
from starlette.middleware.sessions import SessionMiddleware
|
|
from authlib.integrations.starlette_client import OAuth, OAuthError
|
|
from fastapi import Request
|
|
import os
|
|
from starlette.config import Config
|
|
import gradio as gr
|
|
|
|
# The FastAPI application that hosts the OAuth routes below. It is later
# rebound to the return value of gr.mount_gradio_app(), which mounts the two
# Gradio UIs (at /main and /gradio) onto it.
app = FastAPI()
|
|
|
|
|
|
# Google OAuth client credentials come from the environment instead of being
# hard-coded in source control.
# NOTE(review): literal credentials were previously committed in this file;
# treat them as leaked and rotate them in the Google Cloud console.
GOOGLE_CLIENT_ID = os.environ.get("GOOGLE_CLIENT_ID")
GOOGLE_CLIENT_SECRET = os.environ.get("GOOGLE_CLIENT_SECRET")
if not (GOOGLE_CLIENT_ID and GOOGLE_CLIENT_SECRET):
    # Warn rather than crash so the module still imports; the OAuth flow will
    # presumably fail until both variables are provided.
    print("Warning: GOOGLE_CLIENT_ID and/or GOOGLE_CLIENT_SECRET is not set; "
          "Google login will not work.")

# Secret passed to SessionMiddleware. When SECRET_KEY is unset (or empty), fall
# back to a random per-process key: sessions then do not survive a restart and
# are not shared between worker processes, but no secret is baked into the code.
SECRET_KEY = os.environ.get("SECRET_KEY") or os.urandom(32).hex()
|
|
|
|
|
|
# Hand the Google credentials to Authlib through a Starlette Config object.
# NOTE(review): presumably Authlib looks these keys up by the registered client
# name ("google" -> GOOGLE_CLIENT_ID / GOOGLE_CLIENT_SECRET) — confirm.
config_data = dict(
    GOOGLE_CLIENT_ID=GOOGLE_CLIENT_ID,
    GOOGLE_CLIENT_SECRET=GOOGLE_CLIENT_SECRET,
)
starlette_config = Config(environ=config_data)
oauth = OAuth(starlette_config)

# Register the "google" client (used below as oauth.google) from Google's
# OpenID Connect discovery URL, requesting the openid, email and profile scopes.
oauth.register(
    name='google',
    server_metadata_url='https://accounts.google.com/.well-known/openid-configuration',
    client_kwargs={'scope': 'openid email profile'},
)

# Session support: /auth stores the logged-in user in request.session['user'],
# which get_user and /logout read and clear.
app.add_middleware(SessionMiddleware, secret_key=SECRET_KEY)
|
|
|
|
|
|
def get_user(request: Request):
    """Return the logged-in user's display name, or ``None`` if nobody is logged in.

    Reads the ``'user'`` entry that the ``/auth`` route stores in the session and
    returns its ``'name'`` field. A missing or empty session entry counts as
    "not logged in".
    """
    session_user = request.session.get('user')
    return session_user['name'] if session_user else None
|
|
|
|
@app.get('/')
def public(request: Request, user = Depends(get_user)):
    """Site entry point: route visitors to the right Gradio UI.

    ``user`` is whatever ``get_user`` returns (the session user's name, or
    ``None``). Logged-in users go to the main app at ``/gradio/``; everyone
    else goes to the login page at ``/main/``.
    """
    base_url = gr.route_utils.get_root_url(request, "/", None)
    destination = 'gradio' if user else 'main'
    return RedirectResponse(url=f'{base_url}/{destination}/')
|
|
|
|
@app.route('/logout')
async def logout(request: Request):
    """Log the user out and send the browser back to the site root.

    Removing the ``'user'`` session entry makes ``get_user`` return ``None``,
    so ``/`` then routes the visitor to the login page. Removing an entry that
    is already absent is a no-op.
    """
    session = request.session
    session.pop('user', None)
    response = RedirectResponse(url='/')
    return response
|
|
|
|
@app.route('/login')
async def login(request: Request):
    """Start the Google OAuth flow by redirecting the browser to Google.

    The ``redirect_uri`` sent to Google is where it returns the user after
    consent. It must be this app's own ``/auth`` callback, and it must also be
    listed as an authorized redirect URI for the OAuth client in the Google
    Cloud console.
    """
    # Base URL of this app, derived by Gradio's helper from the current request
    # with this route's path ("/login") removed.
    root_url = gr.route_utils.get_root_url(request, "/login", None)

    # Point Google back at our /auth handler. An unrelated URL such as
    # https://www.google.com means /auth is never reached and login never completes.
    redirect_uri = f"{root_url}/auth"

    print("Redirecting to", redirect_uri)
    return await oauth.google.authorize_redirect(request, redirect_uri)
|
|
|
|
@app.route('/auth')
async def auth(request: Request):
    """OAuth callback: finish the Google login and store the user in the session.

    On success, the ``userinfo`` entry of the token response is saved under
    ``request.session['user']`` (read later by ``get_user``), and the browser is
    sent to the Gradio app at ``/gradio``. If the token exchange raises
    ``OAuthError``, or the response carries no ``userinfo``, the browser is sent
    back to ``/``, which routes users who are not logged in to the login page.
    """
    try:
        token = await oauth.google.authorize_access_token(request)
    except OAuthError as exc:
        # Report the caught error instance (not the OAuthError class itself).
        print("Error getting access token", str(exc))
        return RedirectResponse(url='/')

    userinfo = dict(token).get("userinfo")
    if not userinfo:
        # Avoid a KeyError/500 if the ID-token claims were not included.
        print("Token response contained no userinfo")
        return RedirectResponse(url='/')

    request.session['user'] = userinfo
    print("Redirecting to /gradio")
    return RedirectResponse(url='/gradio')
|
|
|
|
with gr.Blocks() as login_demo:
    btn = gr.Button("Login with Google")

    # Client-side click handler: open the /login route (which starts the Google
    # OAuth redirect) in a new browsing context, carrying over the current query
    # string. `const` keeps `url` local instead of leaking an implicit global
    # (which would throw a ReferenceError in strict-mode JavaScript).
    _js_redirect = """
    () => {
        const url = '/login' + window.location.search;
        window.open(url);
    }
    """
    # No Python callback; only the JS handler runs.
    btn.click(None, js=_js_redirect)

# Serve the login page at /main, where `public` sends visitors who are not logged in.
app = gr.mount_gradio_app(app, login_demo, path="/main")
|
|
|
|
def greet(request: gr.Request):
    """Build the welcome message shown when the main Gradio page loads.

    Presumably ``request.username`` is populated by Gradio from the value
    returned by the ``auth_dependency`` passed to ``mount_gradio_app``
    (``get_user`` for the /gradio app). Confirm against the Gradio docs.
    """
    return "Welcome to Gradio, {}".format(request.username)
|
|
|
|
with gr.Blocks() as main_demo:
    # Placeholder text, replaced with a personalised greeting on page load.
    m = gr.Markdown("Welcome to Gradio!")
    # Plain link button that navigates to the /logout route.
    gr.Button("Logout", link="/logout")
    main_demo.load(fn=greet, inputs=None, outputs=m)

# Mount the protected UI at /gradio; get_user is passed as the auth dependency
# so Gradio can determine who (if anyone) is logged in.
app = gr.mount_gradio_app(
    app,
    main_demo,
    path="/gradio",
    auth_dependency=get_user,
)
|
|
|
|
|
|
if __name__ == '__main__':
    # Serve the combined FastAPI + Gradio app. No host/port is passed, so
    # uvicorn's defaults apply.
    uvicorn.run(app=app)