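# Provenance (copied from the hosting page this file was taken from):
#   author: Quentin Lhoest, commit 4f83ec0 ("initial commit")
#   file size: 1.63 kB; security scan: no virus detected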
import logging
from typing import Annotated
from urllib.parse import quote

from fastapi import FastAPI, Request
from fastapi.concurrency import iterate_in_threadpool
from fastapi.responses import StreamingResponse
from outlines import generate
from pydantic import BaseModel, StringConstraints

from generate import model, sampler, stream_file
logger = logging.getLogger(__name__)


# String constrained by the pattern "ok"; used as the single field of the
# smoke-test schema below.
# NOTE(review): the pattern is not anchored (^...$) — confirm whether an
# exact "ok" is intended.
_OK_STRING = Annotated[str, StringConstraints(pattern="ok")]


# Minimal schema the model is asked to fill in at startup.
# NOTE(review): documented with comments rather than a docstring, since
# pydantic may copy a model docstring into the generated JSON schema.
class Status(BaseModel):
    status: _OK_STRING


# Startup check, executed at import time: have the model generate a `Status`
# object from the prompt "status:" and log whatever text it produced.
status_generator = generate.json(model, Status, sampler=sampler)
status_stream = status_generator.stream("status:")
# Strip each streamed piece and drop the ones that end up empty.
status = "".join(filter(None, (piece.strip() for piece in status_stream)))
# Logged at WARNING level, presumably so it is visible under the default
# logging configuration.
logger.warning(f"Model status: {status}")
async def stream_response(filename: str, prompt: str, columns: list[str], seed: int, size: int):
    """Asynchronously re-yield the chunks produced by ``stream_file``.

    ``stream_file`` returns a synchronous iterable. Iterating it directly inside
    this ``async`` generator would run every ``next()`` call on the event loop,
    so a slow chunk (presumably model generation — TODO confirm) would block
    every other request the app is serving. ``iterate_in_threadpool`` performs
    each ``next()`` in a worker thread instead, still one chunk at a time and
    in order.

    Args:
        filename: Forwarded to ``stream_file``.
        prompt: Forwarded to ``stream_file``.
        columns: Forwarded to ``stream_file``.
        seed: Forwarded to ``stream_file``.
        size: Forwarded to ``stream_file``.

    Yields:
        Each chunk produced by ``stream_file``, unchanged.
    """
    chunks = stream_file(
        filename=filename,
        prompt=prompt,
        columns=columns,
        seed=seed,
        size=size,
    )
    # iter() guarantees a true iterator, which older versions of
    # iterate_in_threadpool require.
    # NOTE(review): chunks for different requests may now be produced in
    # parallel worker threads; confirm stream_file/model/sampler tolerate that.
    async for chunk in iterate_in_threadpool(iter(chunks)):
        yield chunk
async def dummy_stream():
    """Yield exactly one empty chunk, then finish.

    ``read_item`` uses this as the response body for non-GET (i.e. HEAD)
    requests, so no generation is started for them.
    """
    empty_chunk = ""
    yield empty_chunk
app = FastAPI()


def _content_disposition(download_name: str) -> str:
    """Build an ``attachment`` Content-Disposition value for *download_name*.

    The name is derived from the URL path, so it may contain spaces, quotes,
    semicolons, control characters or non-Latin-1 text. Interpolating it
    unquoted yields a malformed header for such names, and such a header may
    not be encodable at all. Names made only of URL-safe characters are sent
    as a quoted ``filename``. Anything else is percent-encoded as UTF-8 in
    the RFC 6266 / RFC 5987 ``filename*`` form.
    """
    encoded = quote(download_name)
    if encoded != download_name:
        return f"attachment; filename*=utf-8''{encoded}"
    return f'attachment; filename="{download_name}"'


@app.head("/{filename}.jsonl")
@app.get("/{filename}.jsonl")
async def read_item(request: Request, filename: str, prompt: str = "", columns: str = "", seed: int = 42, size: int = 3):
    """Serve ``<filename>.jsonl`` as a streamed JSON Lines attachment.

    Args:
        request: Incoming request; only its HTTP method is inspected.
        filename: Path segment preceding ``.jsonl``.
        prompt: Forwarded to ``stream_response`` for GET requests.
        columns: Comma-separated names; surrounding whitespace is stripped and
            blank entries are dropped before forwarding (GET only).
        seed: Forwarded to ``stream_response`` for GET requests.
        size: Forwarded to ``stream_response`` for GET requests.
            NOTE(review): not bounded here — confirm stream_file limits it.

    Returns:
        A ``StreamingResponse`` with media type ``text/jsonlines`` and an
        ``attachment`` Content-Disposition header. HEAD requests get the same
        headers with an empty body and never call ``stream_response``.
    """
    if request.method == "GET":
        # Parse into a new name instead of rebinding the `str` parameter.
        column_names = [field.strip() for field in columns.split(",") if field.strip()]
        content = stream_response(
            filename,
            prompt=prompt,
            columns=column_names,
            seed=seed,
            size=size,
        )
    else:
        content = dummy_stream()
    response = StreamingResponse(content, media_type="text/jsonlines")
    response.headers["Content-Disposition"] = _content_disposition(f"{filename}.jsonl")
    return response