from flask import Flask, render_template, request, jsonify |
from qdrant_client import QdrantClient |
from qdrant_client import models |
import torch.nn.functional as F |
import torch |
from torch import Tensor |
from transformers import AutoTokenizer, AutoModel |
from qdrant_client.models import Batch, PointStruct |
from pickle import load, dump |
import numpy as np |
import os, time, sys |
from datetime import datetime as dt |
from datetime import timedelta |
from datetime import timezone |
from faster_whisper import WhisperModel |
import io |
# Flask application object; the route decorators below register against it.
app = Flask(__name__)

# Beam width passed to Whisper decoding in whisper_transcribe().
beamsize = 2
# Speech-to-text model, loaded once at import time (CPU, int8 compute type).
wmodel = WhisperModel("guillaumekln/faster-whisper-small", device="cpu", compute_type="int8")

# Qdrant credentials are read from the environment ONLY.
# NOTE(review): a literal API key and cluster URL used to be assigned here,
# overriding the environment values. Secrets must not live in source control;
# the key that was committed should be treated as leaked and rotated.
qdrant_api_key = os.environ.get("qdrant_api_key")
qdrant_url = os.environ.get("qdrant_url")
if not qdrant_api_key or not qdrant_url:
    # Fail fast with an actionable message instead of a confusing connection
    # error on the first request.
    raise RuntimeError(
        "Set the 'qdrant_api_key' and 'qdrant_url' environment variables "
        "before starting the app"
    )
client = QdrantClient(url=qdrant_url, port=443, api_key=qdrant_api_key, prefer_grpc=False)

# Run the embedding model on the GPU when one is available, else on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def average_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
    """Mean-pool token states along dim 1, ignoring masked-out positions.

    Positions where ``attention_mask`` is zero are replaced by 0.0 before
    summing, and the sum is divided by the per-row sum of the mask.

    Args:
        last_hidden_states: Per-token hidden states; pooled over dim 1.
        attention_mask: Mask whose non-zero entries mark the tokens to keep.

    Returns:
        One pooled vector per row of the batch.
    """
    is_padding = ~attention_mask[..., None].bool()
    kept_states = last_hidden_states.masked_fill(is_padding, 0.0)
    tokens_per_row = attention_mask.sum(dim=1)[..., None]
    return kept_states.sum(dim=1) / tokens_per_row
# Load the tokenizer and the encoder for the 'intfloat/e5-base-v2' checkpoint
# once at import time; the encoder is moved onto the selected torch device.
tokenizer, model = (
    AutoTokenizer.from_pretrained('intfloat/e5-base-v2'),
    AutoModel.from_pretrained('intfloat/e5-base-v2').to(device),
)
def e5embed(query):
    """Embed ``query`` with the E5 model and return it as a flat list of floats.

    The text is tokenized (truncated to 512 tokens), run through the encoder,
    mean-pooled over the attended tokens and L2-normalised.

    Args:
        query: Text to embed, passed straight to the tokenizer.

    Returns:
        The embedding as a plain Python list of floats. The result is
        flattened, so a batch of several inputs would come back as one
        concatenated list rather than one list per input.
    """
    batch_dict = tokenizer(query, max_length=512, padding=True, truncation=True, return_tensors='pt')
    batch_dict = {k: v.to(device) for k, v in batch_dict.items()}
    # Inference only: disable autograd so no computation graph is built and
    # kept alive for every request (saves memory and time; results unchanged).
    with torch.no_grad():
        outputs = model(**batch_dict)
        embeddings = average_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
        embeddings = F.normalize(embeddings, p=2, dim=1)
    # Tensors produced under no_grad carry no graph, so no detach() is needed.
    return embeddings.cpu().numpy().flatten().tolist()
@app.route("/")
def index():
    """Serve the landing page rendered from the ``index.html`` template."""
    page = render_template("index.html")
    return page
def _split_search_prefix(query):
    """Map a raw query to ``(collection_name, vector_name, text)``.

    A leading ``tilc:`` selects the ``context`` vector and a leading ``til:``
    the ``title`` vector of the ``tils`` collection. Anything else targets the
    ``jks`` collection with no named vector and the query left untouched.
    """
    stripped = query.strip()
    for prefix, vector_name in (('tilc:', 'context'), ('til:', 'title')):
        if stripped.startswith(prefix):
            # Drop only the leading marker; later occurrences of the same
            # characters are part of the user's text and must be kept.
            return 'tils', vector_name, stripped[len(prefix):]
    return 'jks', None, query


@app.route("/search", methods=["POST"])
def search():
    """Embed the posted ``query`` and return the nearest Qdrant points as JSON.

    Form fields:
        query: Search text, optionally prefixed with ``til:`` or ``tilc:``
            to search the ``tils`` collection instead of ``jks``.

    Returns:
        A JSON list of ``{"text", "date", "id"}`` objects, or an empty JSON
        list when a returned point lacks the expected payload fields.
    """
    query = request.form["query"]
    topN = 200        # result limit for the default 'jks' collection
    tils_limit = 100  # result limit for the 'tils' collection
    print('QUERY: ', query)
    collection_name, qvector, query = _split_search_prefix(query)

    timh = time.time()
    sq = e5embed(query)
    print('EMBEDDING TIME: ', time.time() - timh)

    timh = time.time()
    if collection_name == "jks":
        results = client.search(collection_name=collection_name, query_vector=sq,
                                with_payload=True, limit=topN)
    else:
        # 'tils' stores several named vectors, so name the one to search.
        results = client.search(collection_name=collection_name, query_vector=(qvector, sq),
                                with_payload=True, limit=tils_limit)
    print('SEARCH TIME: ', time.time() - timh)

    try:
        results = [
            {"text": x.payload['text'], "date": str(int(x.payload['date'])), "id": x.id}
            for x in results
        ]
    except (KeyError, TypeError, ValueError, OverflowError) as err:
        # Best effort, as before: a malformed payload yields an empty result
        # set, but the reason is now logged instead of silently discarded.
        print('MALFORMED SEARCH RESULT: ', repr(err))
        return jsonify([])
    return jsonify(results)
@app.route("/delete_joke", methods=["POST"])
def delete_joke():
    """Delete one point from the ``jks`` collection by its integer id.

    Form fields:
        id: Integer id of the point to delete.

    Returns:
        ``{"deleted": true}`` on success, or a 400 response with
        ``{"deleted": false, "error": ...}`` when ``id`` is not an integer.
    """
    joke_id = request.form["id"]
    try:
        point_id = int(joke_id)
    except ValueError:
        # Reject bad input explicitly rather than failing with a 500.
        return jsonify({"deleted": False, "error": "id must be an integer"}), 400
    print('Deleting joke no', joke_id)
    client.delete(
        collection_name="jks",
        points_selector=models.PointIdsList(points=[point_id]),
    )
    return jsonify({"deleted": True})
@app.route("/whisper_transcribe", methods=["POST"])
def whisper_transcribe():
    """Transcribe an uploaded audio file with Whisper and return the text.

    Expects a multipart upload under the ``audio`` field whose filename ends
    in one of the allowed extensions.

    Returns:
        ``{"transcription": text}`` on success, or a 400 response with an
        ``error`` message when the file is missing or its extension is not
        allowed.
    """
    if 'audio' not in request.files:
        return jsonify({'error': 'No file provided'}), 400
    audio_file = request.files['audio']
    # NOTE(review): 'm4v' is a video extension; possibly 'm4a' was meant - confirm.
    allowed_extensions = {'mp3', 'wav', 'ogg', 'm4v'}
    # Require an actual ".ext" suffix: a bare name such as "wav" has no
    # extension and must not pass the check.
    _, dot, extension = (audio_file.filename or '').lower().rpartition('.')
    if not (dot and extension in allowed_extensions):
        return jsonify({'error': 'Invalid file format'}), 400

    print('Transcribing audio')
    # Hand the upload to the model as an in-memory binary stream.
    audio_stream = io.BytesIO(audio_file.read())
    segments, _ = wmodel.transcribe(audio_stream, beam_size=beamsize)
    # The timer starts after transcribe() returns, so it measures the time
    # spent consuming the segments.
    starttime = time.time()
    text = ''.join(segment.text for segment in segments)
    print('Time to transcribe:', time.time() - starttime, 'seconds')
    return jsonify({'transcription': text})
if __name__ == "__main__":
    # Script entry point: start Flask's built-in server on port 7860 with
    # debug mode switched on.
    # NOTE(review): debug mode should not be reachable from untrusted
    # networks - confirm how this is deployed.
    run_options = {"host": "", "debug": True, "port": 7860}
    app.run(**run_options)