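# song-recommender / server.py
# (Header recovered from a repository web view: author jrno, last commit
#  f6b6982 "Add dockerfile", file size 1.5 kB. The "raw", "history" and
#  "blame" entries were page navigation links, not part of the code.)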
from fastai.collab import load_learner
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from custom_models import DotProductBias
import asyncio
import uvicorn
import pandas as pd
import os
# FastAPI application serving the recommendation endpoint.
app = FastAPI()

# CORS: every origin, method and header is accepted, with credentials allowed.
# NOTE(review): a wildcard origin combined with allow_credentials=True is very
# permissive for a browser-facing API — confirm this is intended.
_cors_settings = dict(
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_cors_settings)
# Exported learner file; a relative path, so it is opened from the
# process's current working directory.
model_filename = 'model.pkl'


async def setup_learner():
    """Load the exported learner from ``model_filename`` for CPU inference.

    Returns:
        The object returned by ``load_learner`` with ``dls.device`` set to
        ``'cpu'``.
    """
    learner = load_learner(model_filename)
    # Keep the learner's DataLoaders on the CPU for serving.
    learner.dls.device = 'cpu'
    return learner
# Loaded learner; None until startup_event() has run.
learn = None


@app.on_event("startup")
async def startup_event():
    """Setup the learner on server start.

    Awaits ``setup_learner()`` and stores the result in the module-level
    ``learn`` that the request handler reads.
    """
    global learn
    # A single coroutine needs no ensure_future/gather wrapper (and no event
    # loop handle): awaiting it directly yields the same learner and
    # propagates the same exceptions.
    learn = await setup_learner()
# Candidate songs scored on every request, as "title, artist, year" entries.
# NOTE(review): hard-coded placeholder — presumably this should become the set
# of songs the requesting user has not listened to yet; confirm.
NOT_LISTENED_SONGS = (
    "Revelry, Kings of Leon, 2008",
    "Gears, Miss May I, 2010",
    "Sexy Bitch, David Guetta, 2009",
)


@app.get("/recommend/{user_id}")
async def analyze(user_id: str):
    """Score the candidate songs for ``user_id`` with the loaded learner.

    Args:
        user_id: User identifier from the URL path. It fills the ``user_id``
            column of the frame handed to the learner (previously a fixed,
            hard-coded id was used and this argument was ignored).

    Returns:
        ``{"result": [...]}``: the first element of ``learn.get_preds``
        (the predictions) converted to Python lists, in the same order as
        ``NOT_LISTENED_SONGS``.
    """
    songs = list(NOT_LISTENED_SONGS)
    input_dataframe = pd.DataFrame(
        {"user_id": [user_id] * len(songs), "entry": songs}
    )
    test_dl = learn.dls.test_dl(input_dataframe)
    # Inference runs synchronously inside this coroutine, so it blocks the
    # event loop for its duration; requests are therefore scored one at a time.
    predictions = learn.get_preds(dl=test_dl)[0]
    return {"result": predictions.numpy().tolist()}
if __name__ == "__main__":
    # Listen on all interfaces; the port comes from $PORT (default 7860).
    port = int(os.environ.get("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=port)