{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "3c9e1f20",
   "metadata": {},
   "source": [
    "# Chatbot API proxy and Gradio UI\n",
    "\n",
    "This notebook writes a FastAPI application into `app/`, generates `requirements.txt`, seeds user accounts from `seeds.csv`, commits the code and finally starts the server.\n",
    "\n",
    "- `app/config.py` loads the `.env.<APP_ENV>` file (`dev`, `test` or `prod`).\n",
    "- `app/db.py`, `app/schemas.py` and `app/users.py` set up fastapi-users JWT authentication on SQLAlchemy.\n",
    "- `app/app.py` exposes an authenticated `/v1/chat/completions` passthrough to OpenAI and mounts a Gradio chat UI at `/gradio`.\n",
    "\n",
    "The last cell (`python main.py`) blocks until the server is stopped, so it must stay last."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4fca2e60",
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install -q gradio fastapi 'fastapi-users-db-sqlalchemy<5.0.0' 'openai<1.0' uvicorn httpx requests pydantic sqlalchemy python-dotenv asyncpg pipreqs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7b2d4e91",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# %%writefile does not create missing directories.\n",
    "os.makedirs(\"app\", exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5e8a0c13",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%writefile app/config.py\n",
    "\"\"\"Environment-specific configuration loading shared by every app module.\"\"\"\n",
    "import os\n",
    "\n",
    "from dotenv import load_dotenv\n",
    "\n",
    "# Maps each supported APP_ENV value to the dotenv file that configures it.\n",
    "ENV_FILES = {\n",
    "    \"dev\": \".env.dev\",\n",
    "    \"test\": \".env.test\",\n",
    "    \"prod\": \".env.prod\",\n",
    "}\n",
    "\n",
    "\n",
    "def load_environment() -> str:\n",
    "    \"\"\"Load the .env file selected by APP_ENV (default \"dev\").\n",
    "\n",
    "    Variables already present in the process environment are not overridden\n",
    "    (python-dotenv's default behaviour).\n",
    "\n",
    "    Returns:\n",
    "        The active environment name.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If APP_ENV is not one of the keys of ENV_FILES.\n",
    "    \"\"\"\n",
    "    current_environment = os.getenv(\"APP_ENV\", \"dev\")\n",
    "    env_file = ENV_FILES.get(current_environment)\n",
    "    if env_file is None:\n",
    "        raise ValueError(\"Invalid environment specified\")\n",
    "    load_dotenv(env_file)\n",
    "    return current_environment\n",
    "\n",
    "\n",
    "def require_env(name: str) -> str:\n",
    "    \"\"\"Return the value of a mandatory environment variable.\n",
    "\n",
    "    Raises:\n",
    "        RuntimeError: If the variable is unset or empty.\n",
    "    \"\"\"\n",
    "    value = os.getenv(name)\n",
    "    if not value:\n",
    "        raise RuntimeError(f\"Required environment variable {name} is not set\")\n",
    "    return value\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a4ffa93a",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%writefile app/db.py\n",
    "from typing import AsyncGenerator\n",
    "\n",
    "from fastapi import Depends\n",
    "from fastapi_users.db import SQLAlchemyBaseUserTableUUID, SQLAlchemyUserDatabase\n",
    "from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine\n",
    "from sqlalchemy.ext.declarative import DeclarativeMeta, declarative_base\n",
    "from sqlalchemy.orm import sessionmaker\n",
    "\n",
    "from app.config import load_environment, require_env\n",
    "\n",
    "load_environment()\n",
    "\n",
    "# Fail fast with a clear message instead of passing None to create_async_engine.\n",
    "DATABASE_URL = require_env(\"DB_CONNECTION_STRING\")\n",
    "Base: DeclarativeMeta = declarative_base()\n",
    "\n",
    "\n",
    "class User(SQLAlchemyBaseUserTableUUID, Base):\n",
    "    \"\"\"User table; its columns come from fastapi-users' SQLAlchemyBaseUserTableUUID.\"\"\"\n",
    "\n",
    "\n",
    "engine = create_async_engine(DATABASE_URL)\n",
    "async_session_maker = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)\n",
    "\n",
    "\n",
    "async def create_db_and_tables():\n",
    "    \"\"\"Create every table declared on Base (no migrations).\"\"\"\n",
    "    async with engine.begin() as conn:\n",
    "        await conn.run_sync(Base.metadata.create_all)\n",
    "\n",
    "\n",
    "async def get_async_session() -> AsyncGenerator[AsyncSession, None]:\n",
    "    \"\"\"FastAPI dependency yielding an AsyncSession that is closed afterwards.\"\"\"\n",
    "    async with async_session_maker() as session:\n",
    "        yield session\n",
    "\n",
    "\n",
    "async def get_user_db(session: AsyncSession = Depends(get_async_session)):\n",
    "    \"\"\"FastAPI dependency wrapping the session in a fastapi-users database adapter.\"\"\"\n",
    "    yield SQLAlchemyUserDatabase(session, User)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d2a08335",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%writefile app/schemas.py\n",
    "import uuid\n",
    "\n",
    "from fastapi_users import schemas\n",
    "\n",
    "\n",
    "class UserRead(schemas.BaseUser[uuid.UUID]):\n",
    "    pass\n",
    "\n",
    "\n",
    "class UserCreate(schemas.BaseUserCreate):\n",
    "    pass\n",
    "\n",
    "\n",
    "class UserUpdate(schemas.BaseUserUpdate):\n",
    "    pass\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9d649fcc",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%writefile app/users.py\n",
    "import uuid\n",
    "from typing import Optional\n",
    "\n",
    "from fastapi import Depends, Request\n",
    "from fastapi_users import BaseUserManager, FastAPIUsers, UUIDIDMixin\n",
    "from fastapi_users.authentication import (\n",
    "    AuthenticationBackend,\n",
    "    BearerTransport,\n",
    "    JWTStrategy,\n",
    ")\n",
    "from fastapi_users.db import SQLAlchemyUserDatabase\n",
    "\n",
    "from app.config import load_environment, require_env\n",
    "from app.db import User, get_user_db\n",
    "\n",
    "load_environment()\n",
    "\n",
    "# Signs JWTs and reset/verification tokens; refuse to start without it.\n",
    "SECRET = require_env(\"APP_SECRET\")\n",
    "\n",
    "\n",
    "class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):\n",
    "    \"\"\"fastapi-users manager for UUID-keyed users; the lifecycle hooks only print.\"\"\"\n",
    "\n",
    "    reset_password_token_secret = SECRET\n",
    "    verification_token_secret = SECRET\n",
    "\n",
    "    async def on_after_register(self, user: User, request: Optional[Request] = None):\n",
    "        print(f\"User {user.id} has registered.\")\n",
    "\n",
    "    # NOTE: the hooks below print tokens to stdout; replace them with real\n",
    "    # delivery (and stop logging tokens) before using this outside development.\n",
    "    async def on_after_forgot_password(\n",
    "        self, user: User, token: str, request: Optional[Request] = None\n",
    "    ):\n",
    "        print(f\"User {user.id} has forgot their password. Reset token: {token}\")\n",
    "\n",
    "    async def on_after_request_verify(\n",
    "        self, user: User, token: str, request: Optional[Request] = None\n",
    "    ):\n",
    "        print(f\"Verification requested for user {user.id}. Verification token: {token}\")\n",
    "\n",
    "\n",
    "async def get_user_manager(user_db: SQLAlchemyUserDatabase = Depends(get_user_db)):\n",
    "    \"\"\"FastAPI dependency yielding a UserManager bound to the user database.\"\"\"\n",
    "    yield UserManager(user_db)\n",
    "\n",
    "\n",
    "bearer_transport = BearerTransport(tokenUrl=\"auth/jwt/login\")\n",
    "\n",
    "\n",
    "def get_jwt_strategy() -> JWTStrategy:\n",
    "    \"\"\"JWT strategy signed with SECRET; tokens live for 3600 seconds.\"\"\"\n",
    "    return JWTStrategy(secret=SECRET, lifetime_seconds=3600)\n",
    "\n",
    "\n",
    "auth_backend = AuthenticationBackend(\n",
    "    name=\"jwt\",\n",
    "    transport=bearer_transport,\n",
    "    get_strategy=get_jwt_strategy,\n",
    ")\n",
    "\n",
    "fastapi_users = FastAPIUsers[User, uuid.UUID](get_user_manager, [auth_backend])\n",
    "\n",
    "current_active_user = fastapi_users.current_user(active=True)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d2250413",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%writefile app/app.py\n",
    "import os\n",
    "\n",
    "import gradio as gr\n",
    "import httpx\n",
    "import openai\n",
    "import requests\n",
    "from fastapi import Depends, FastAPI, HTTPException, Request\n",
    "from fastapi.responses import JSONResponse, Response\n",
    "\n",
    "from app.config import load_environment\n",
    "from app.db import User, create_db_and_tables\n",
    "from app.schemas import UserCreate, UserRead, UserUpdate\n",
    "from app.users import auth_backend, current_active_user, fastapi_users\n",
    "import examples as chatbot_examples\n",
    "\n",
    "load_environment()\n",
    "\n",
    "# openai<1.0 reads OPENAI_API_KEY once, when the module is imported, which is\n",
    "# before the .env file above has been loaded; set the key explicitly.\n",
    "openai.api_key = os.getenv(\"OPENAI_API_KEY\")\n",
    "\n",
    "OPENAI_CHAT_COMPLETIONS_URL = \"https://api.openai.com/v1/chat/completions\"\n",
    "# Models callers may not use; requests for them are rewritten to FALLBACK_MODEL.\n",
    "RESTRICTED_MODELS = (\"gpt-4\", \"gpt-4-32k\")\n",
    "FALLBACK_MODEL = \"gpt-3.5-turbo\"\n",
    "# httpx's default 5 s timeout is too short for chat completions.\n",
    "UPSTREAM_TIMEOUT_SECONDS = 120.0\n",
    "# Upper bound for the login request api_login sends to this server.\n",
    "LOGIN_TIMEOUT_SECONDS = 30.0\n",
    "\n",
    "\n",
    "def api_login(email, password):\n",
    "    \"\"\"Exchange credentials for a JWT via this server's /auth/jwt/login route.\n",
    "\n",
    "    The URL is built from the APP_SCHEME, APP_HOST and APP_PORT variables.\n",
    "\n",
    "    Returns:\n",
    "        (True, access_token) on success, otherwise (False, error_message).\n",
    "    \"\"\"\n",
    "    port = os.getenv(\"APP_PORT\")\n",
    "    scheme = os.getenv(\"APP_SCHEME\")\n",
    "    host = os.getenv(\"APP_HOST\")\n",
    "    url = f\"{scheme}://{host}:{port}/auth/jwt/login\"\n",
    "\n",
    "    try:\n",
    "        response = requests.post(\n",
    "            url,\n",
    "            data={\"username\": email, \"password\": password},\n",
    "            headers={\"Content-Type\": \"application/x-www-form-urlencoded\"},\n",
    "            timeout=LOGIN_TIMEOUT_SECONDS,\n",
    "        )\n",
    "        response_json = response.json()\n",
    "    except (requests.RequestException, ValueError) as exc:\n",
    "        # Connection failures and non-JSON bodies become a normal login failure.\n",
    "        return False, f\"Login service unavailable: {exc}\"\n",
    "\n",
    "    if response.status_code == 200:\n",
    "        return True, response_json[\"access_token\"]\n",
    "\n",
    "    # `detail` is not guaranteed to be a string (validation errors carry a list),\n",
    "    # so stringify it for gr.Error.\n",
    "    detail = response_json.get(\"detail\") if isinstance(response_json, dict) else None\n",
    "    return False, str(detail or \"Login failed\")\n",
    "\n",
    "\n",
    "def gradio_login(username, password):\n",
    "    \"\"\"Gradio auth callback; returns True only for valid credentials.\n",
    "\n",
    "    api_login returns a non-empty tuple, which is always truthy, so it must not\n",
    "    be used as the Gradio auth function directly.\n",
    "    \"\"\"\n",
    "    successful, _ = api_login(username, password)\n",
    "    return successful\n",
    "\n",
    "\n",
    "def get_api_key(email, password):\n",
    "    \"\"\"Return (API base URL, JWT) for the UI, raising gr.Error on failure.\"\"\"\n",
    "    successful, message = api_login(email, password)\n",
    "    if not successful:\n",
    "        raise gr.Error(message)\n",
    "    return os.getenv(\"APP_API_BASE\"), message\n",
    "\n",
    "\n",
    "def get_ai_reply(message, model=\"gpt-3.5-turbo\", system_message=None, temperature=0, message_history=None):\n",
    "    \"\"\"Send one user message (plus optional system prompt and history) to OpenAI.\n",
    "\n",
    "    Args:\n",
    "        message: The new user message.\n",
    "        model: Chat model name.\n",
    "        system_message: Optional system prompt, placed first.\n",
    "        temperature: Sampling temperature.\n",
    "        message_history: Optional prior messages as {\"role\", \"content\"} dicts.\n",
    "\n",
    "    Returns:\n",
    "        The assistant's reply with surrounding whitespace stripped.\n",
    "    \"\"\"\n",
    "    messages = []\n",
    "    if system_message is not None:\n",
    "        messages.append({\"role\": \"system\", \"content\": system_message})\n",
    "    if message_history is not None:\n",
    "        messages.extend(message_history)\n",
    "    messages.append({\"role\": \"user\", \"content\": message})\n",
    "\n",
    "    completion = openai.ChatCompletion.create(\n",
    "        model=model,\n",
    "        messages=messages,\n",
    "        temperature=temperature\n",
    "    )\n",
    "    return completion.choices[0].message.content.strip()\n",
    "\n",
    "\n",
    "def chat(model, system_message, message, chatbot_messages, history_state):\n",
    "    \"\"\"Gradio handler for the Send button.\n",
    "\n",
    "    Returns:\n",
    "        None (clears the message box), the updated chatbot pairs and the\n",
    "        updated OpenAI-format history.\n",
    "    \"\"\"\n",
    "    chatbot_messages = chatbot_messages or []\n",
    "    history_state = history_state or []\n",
    "\n",
    "    try:\n",
    "        ai_reply = get_ai_reply(message, model=model, system_message=system_message, message_history=history_state)\n",
    "    except Exception as e:\n",
    "        # gr.Error expects a message string, not an exception object.\n",
    "        raise gr.Error(str(e)) from e\n",
    "\n",
    "    chatbot_messages.append((message, ai_reply))\n",
    "    history_state.append({\"role\": \"user\", \"content\": message})\n",
    "    history_state.append({\"role\": \"assistant\", \"content\": ai_reply})\n",
    "    return None, chatbot_messages, history_state\n",
    "\n",
    "\n",
    "def get_chatbot_app(additional_examples=None):\n",
    "    \"\"\"Build the Gradio UI: a chat tab and a tab that issues API keys.\n",
    "\n",
    "    Args:\n",
    "        additional_examples: Extra examples forwarded to\n",
    "            chatbot_examples.load_examples (an empty list when omitted).\n",
    "    \"\"\"\n",
    "    if additional_examples is None:\n",
    "        additional_examples = []\n",
    "    examples = chatbot_examples.load_examples(additional=additional_examples)\n",
    "\n",
    "    def get_examples():\n",
    "        \"\"\"Names shown in the examples dropdown.\"\"\"\n",
    "        return [example[\"name\"] for example in examples]\n",
    "\n",
    "    def choose_example(index):\n",
    "        \"\"\"Load example `index` into the prompt boxes and clear the conversation.\"\"\"\n",
    "        if index is not None:\n",
    "            system_message = examples[index][\"system_message\"].strip()\n",
    "            user_message = examples[index][\"message\"].strip()\n",
    "            return system_message, user_message, [], []\n",
    "        return \"\", \"\", [], []\n",
    "\n",
    "    with gr.Blocks() as app:\n",
    "        with gr.Tab(\"Conversation\"):\n",
    "            with gr.Row():\n",
    "                with gr.Column():\n",
    "                    example_dropdown = gr.Dropdown(get_examples(), label=\"Examples\", type=\"index\")\n",
    "                    example_load_btn = gr.Button(value=\"Load\")\n",
    "                    system_message = gr.TextArea(label=\"System Message (Prompt)\", value=\"You are a helpful assistant.\", lines=20, max_lines=400)\n",
    "                with gr.Column():\n",
    "                    model_selector = gr.Dropdown(\n",
    "                        [\"gpt-3.5-turbo\"],\n",
    "                        label=\"Model\",\n",
    "                        value=\"gpt-3.5-turbo\"\n",
    "                    )\n",
    "                    chatbot = gr.Chatbot(label=\"Conversation\")\n",
    "                    message = gr.Textbox(label=\"Message\")\n",
    "                    # OpenAI-format conversation history for this browser session.\n",
    "                    history_state = gr.State()\n",
    "                    send_btn = gr.Button(value=\"Send\")\n",
    "\n",
    "            example_load_btn.click(choose_example, inputs=[example_dropdown], outputs=[system_message, message, chatbot, history_state])\n",
    "            send_btn.click(chat, inputs=[model_selector, system_message, message, chatbot, history_state], outputs=[message, chatbot, history_state])\n",
    "        with gr.Tab(\"Get API Key\"):\n",
    "            email_box = gr.Textbox(label=\"Email Address\", placeholder=\"Student Email\")\n",
    "            password_box = gr.Textbox(label=\"Password\", type=\"password\", placeholder=\"Student ID\")\n",
    "            generate_btn = gr.Button(value=\"Generate\")\n",
    "            api_host_box = gr.Textbox(label=\"OpenAI API Base\", interactive=False)\n",
    "            api_key_box = gr.Textbox(label=\"OpenAI API Key\", interactive=False)\n",
    "            generate_btn.click(get_api_key, inputs=[email_box, password_box], outputs=[api_host_box, api_key_box])\n",
    "    return app\n",
    "\n",
    "\n",
    "app = FastAPI()\n",
    "\n",
    "app.include_router(\n",
    "    fastapi_users.get_auth_router(auth_backend), prefix=\"/auth/jwt\", tags=[\"auth\"]\n",
    ")\n",
    "app.include_router(\n",
    "    fastapi_users.get_register_router(UserRead, UserCreate),\n",
    "    prefix=\"/auth\",\n",
    "    tags=[\"auth\"],\n",
    ")\n",
    "app.include_router(\n",
    "    fastapi_users.get_users_router(UserRead, UserUpdate),\n",
    "    prefix=\"/users\",\n",
    "    tags=[\"users\"],\n",
    ")\n",
    "\n",
    "\n",
    "@app.get(\"/authenticated-route\")\n",
    "async def authenticated_route(user: User = Depends(current_active_user)):\n",
    "    return {\"message\": f\"Hello {user.email}!\"}\n",
    "\n",
    "\n",
    "@app.post(\"/v1/chat/completions\")\n",
    "async def openai_api_chat_completions_passthrough(\n",
    "    request: Request,\n",
    "    user: User = Depends(current_active_user),\n",
    "):\n",
    "    \"\"\"Forward a chat-completions request to OpenAI using the server's API key.\n",
    "\n",
    "    The `user` dependency restricts the route to authenticated, active users.\n",
    "    Requests for RESTRICTED_MODELS are rewritten to FALLBACK_MODEL. The\n",
    "    upstream status code and body are passed back to the caller.\n",
    "    \"\"\"\n",
    "    try:\n",
    "        request_data = await request.json()\n",
    "    except ValueError:\n",
    "        raise HTTPException(status_code=400, detail=\"Request body must be valid JSON\")\n",
    "    if not isinstance(request_data, dict):\n",
    "        raise HTTPException(status_code=400, detail=\"Request body must be a JSON object\")\n",
    "\n",
    "    if request_data.get(\"model\") in RESTRICTED_MODELS:\n",
    "        print(f\"User requested {request_data['model']}, falling back to {FALLBACK_MODEL}\")\n",
    "        request_data[\"model\"] = FALLBACK_MODEL\n",
    "\n",
    "    openai_api_key = os.getenv(\"OPENAI_API_KEY\")\n",
    "    # Async client: a slow upstream call must not block the event loop.\n",
    "    async with httpx.AsyncClient(timeout=UPSTREAM_TIMEOUT_SECONDS) as client:\n",
    "        try:\n",
    "            response = await client.post(\n",
    "                OPENAI_CHAT_COMPLETIONS_URL,\n",
    "                json=request_data,\n",
    "                headers={\"Authorization\": f\"Bearer {openai_api_key}\"},\n",
    "            )\n",
    "        except httpx.HTTPError as exc:\n",
    "            raise HTTPException(status_code=502, detail=f\"Upstream request failed: {exc}\") from exc\n",
    "\n",
    "    try:\n",
    "        content = response.json()\n",
    "    except ValueError:\n",
    "        # Non-JSON bodies (e.g. streamed responses) are relayed as-is.\n",
    "        return Response(\n",
    "            content=response.content,\n",
    "            status_code=response.status_code,\n",
    "            media_type=response.headers.get(\"content-type\"),\n",
    "        )\n",
    "    return JSONResponse(content=content, status_code=response.status_code)\n",
    "\n",
    "\n",
    "@app.on_event(\"startup\")\n",
    "async def on_startup():\n",
    "    # Not needed if you setup a migration system like Alembic\n",
    "    await create_db_and_tables()\n",
    "\n",
    "\n",
    "gradio_gui = get_chatbot_app()\n",
    "# Gradio treats any truthy return value as a successful login, so use the bool wrapper.\n",
    "gradio_gui.auth = gradio_login\n",
    "gradio_gui.auth_message = \"Hello\"\n",
    "app = gr.mount_gradio_app(app, gradio_gui, path=\"/gradio\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f089dfd7",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%writefile main.py\n",
    "import uvicorn\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    uvicorn.run(\"app.app:app\", host=\"0.0.0.0\", port=8000, log_level=\"info\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cb53f0ae",
   "metadata": {},
   "outputs": [],
   "source": [
    "# --force lets the cell be re-run once requirements.txt exists.\n",
    "!python -m pipreqs.pipreqs --force ."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8d1f6a42",
   "metadata": {},
   "source": [
    "## Seed users\n",
    "\n",
    "Creates one account per `email,password` row of `seeds.csv`; accounts that already exist are reported and skipped."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "65658ef7",
   "metadata": {},
   "outputs": [],
   "source": [
    "import contextlib\n",
    "import csv\n",
    "\n",
    "from fastapi_users.exceptions import UserAlreadyExists\n",
    "\n",
    "from app.db import create_db_and_tables, get_async_session, get_user_db\n",
    "from app.schemas import UserCreate\n",
    "from app.users import get_user_manager\n",
    "\n",
    "SEEDS_PATH = \"seeds.csv\"\n",
    "\n",
    "# Wrap the FastAPI dependency generators so they can be used outside a request.\n",
    "get_async_session_context = contextlib.asynccontextmanager(get_async_session)\n",
    "get_user_db_context = contextlib.asynccontextmanager(get_user_db)\n",
    "get_user_manager_context = contextlib.asynccontextmanager(get_user_manager)\n",
    "\n",
    "\n",
    "async def create_user(email: str, password: str, is_superuser: bool = False):\n",
    "    \"\"\"Create one user, printing a notice instead of failing if it already exists.\"\"\"\n",
    "    try:\n",
    "        async with get_async_session_context() as session:\n",
    "            async with get_user_db_context(session) as user_db:\n",
    "                async with get_user_manager_context(user_db) as user_manager:\n",
    "                    user = await user_manager.create(\n",
    "                        UserCreate(\n",
    "                            email=email, password=password, is_superuser=is_superuser\n",
    "                        )\n",
    "                    )\n",
    "                    print(f\"User created {user}\")\n",
    "    except UserAlreadyExists:\n",
    "        print(f\"User {email} already exists\")\n",
    "\n",
    "\n",
    "async def seed_users(path: str = SEEDS_PATH):\n",
    "    \"\"\"Create a user for every `email,password` row of the CSV at `path`.\n",
    "\n",
    "    Tables are created first so seeding also works on a fresh database;\n",
    "    blank rows are skipped.\n",
    "    \"\"\"\n",
    "    await create_db_and_tables()\n",
    "    with open(path, mode=\"r\", newline=\"\") as csv_file:\n",
    "        for row in csv.reader(csv_file):\n",
    "            if not row:\n",
    "                continue\n",
    "            email, password = row[0], row[1]\n",
    "            await create_user(email=email, password=password)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c4e7b9a0",
   "metadata": {},
   "outputs": [],
   "source": [
    "await seed_users()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9553c6e6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stage only the generated sources (not .env files or seeds.csv) before committing.\n",
    "!git add app/*.py main.py requirements.txt && git commit -m \"adding chatbot\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a20f7f8c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Blocks until the server is stopped; keep this as the last cell.\n",
    "!python main.py"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}