from flask import Flask, render_template, request, jsonify

from qdrant_client import QdrantClient, models

import torch
import torch.nn.functional as F
from torch import Tensor
from transformers import AutoTokenizer, AutoModel

from faster_whisper import WhisperModel

import io
import os
import time

app = Flask(__name__)

# Whisper speech-to-text: small model, CPU-only, int8-quantized to keep memory low.
beamsize = 2
wmodel = WhisperModel("guillaumekln/faster-whisper-small", device="cpu", compute_type="int8")
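
# beam_size=2 trades a little accuracy for speed; faster-whisper's default is 5,
# and 1 amounts to greedy decoding.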

# Qdrant credentials are read from the environment; never hardcode them in source.
qdrant_api_key = os.environ.get("qdrant_api_key")
qdrant_url = os.environ.get("qdrant_url")

client = QdrantClient(url=qdrant_url, port=443, api_key=qdrant_api_key, prefer_grpc=False)
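
# The "jks" and "tils" collections are assumed to already exist in this cluster.
# For reference, a sketch of how "tils" (two named vectors, as the /search route
# expects) could be created, assuming 768-dim e5-base-v2 embeddings and cosine
# distance:
#
#   client.recreate_collection(
#       collection_name="tils",
#       vectors_config={
#           "title": models.VectorParams(size=768, distance=models.Distance.COSINE),
#           "context": models.VectorParams(size=768, distance=models.Distance.COSINE),
#       },
#   )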

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def average_pool(last_hidden_states: Tensor,
                 attention_mask: Tensor) -> Tensor:
    """Mean-pool token embeddings, ignoring padding positions."""
    last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
    return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]


tokenizer = AutoTokenizer.from_pretrained('intfloat/e5-base-v2')
model = AutoModel.from_pretrained('intfloat/e5-base-v2').to(device)


def e5embed(query):
    """Embed a string with e5-base-v2 and return a plain list of 768 floats."""
    batch_dict = tokenizer(query, max_length=512, padding=True, truncation=True, return_tensors='pt')
    batch_dict = {k: v.to(device) for k, v in batch_dict.items()}
    with torch.no_grad():  # inference only; skip building the autograd graph
        outputs = model(**batch_dict)
    embeddings = average_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
    embeddings = F.normalize(embeddings, p=2, dim=1)
    return embeddings.cpu().numpy().flatten().tolist()
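
# Note (assumption about intended usage): the e5 model card recommends prefixing
# inputs with "query: " / "passage: " at embedding time, e.g.
#   e5embed("query: " + user_text)
# The routes below pass raw text; adding matching prefixes to both the stored
# passages and the queries may improve retrieval quality.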


@app.route("/")
def index():
    return render_template("index.html")


@app.route("/search", methods=["POST"])
def search():
    query = request.form["query"]
    topN = 200

    print('QUERY:', query)
    # Prefix routing: "tilc:" searches TIL contexts, "til:" searches TIL titles,
    # anything else searches the jokes collection.
    if query.strip().startswith('tilc:'):
        collection_name = 'tils'
        qvector = "context"
        query = query.replace('tilc:', '')
    elif query.strip().startswith('til:'):
        collection_name = 'tils'
        qvector = "title"
        query = query.replace('til:', '')
    else:
        collection_name = 'jks'

    timh = time.time()
    sq = e5embed(query)
    print('EMBEDDING TIME:', time.time() - timh)

    timh = time.time()
    if collection_name == "jks":
        results = client.search(collection_name=collection_name, query_vector=sq, with_payload=True, limit=topN)
    else:
        # "tils" uses named vectors, so the query vector is passed as (name, vector).
        results = client.search(collection_name=collection_name, query_vector=(qvector, sq), with_payload=True, limit=100)
    print('SEARCH TIME:', time.time() - timh)

    try:
        results = [{"text": x.payload['text'], "date": str(int(x.payload['date'])), "id": x.id} for x in results]
        return jsonify(results)
    except Exception:  # e.g. a payload missing expected keys: return an empty result set
        return jsonify([])
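
# Example request (sketch, assuming the app runs locally on port 7860):
#   curl -X POST -d "query=til:python generators" http://localhost:7860/search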


@app.route("/delete_joke", methods=["POST"])
def delete_joke():
    joke_id = request.form["id"]
    print('Deleting joke no', joke_id)
    client.delete(collection_name="jks", points_selector=models.PointIdsList(points=[int(joke_id)]))
    return jsonify({"deleted": True})


@app.route("/whisper_transcribe", methods=["POST"])
def whisper_transcribe():
    if 'audio' not in request.files:
        return jsonify({'error': 'No file provided'}), 400

    audio_file = request.files['audio']
    allowed_extensions = {'mp3', 'wav', 'ogg', 'm4a'}
    if not (audio_file and audio_file.filename.lower().split('.')[-1] in allowed_extensions):
        return jsonify({'error': 'Invalid file format'}), 400

    print('Transcribing audio')
    audio_bytes = audio_file.read()
    audio_file = io.BytesIO(audio_bytes)

    starttime = time.time()
    # transcribe() returns a lazy generator; decoding happens while iterating below.
    segments, info = wmodel.transcribe(audio_file, beam_size=beamsize)
    text = ''
    for segment in segments:
        text += segment.text
    print('Time to transcribe:', time.time() - starttime, 'seconds')

    return jsonify({'transcription': text})
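
# Example request (sketch, hypothetical file name):
#   curl -X POST -F "audio=@memo.m4a" http://localhost:7860/whisper_transcribe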


if __name__ == "__main__":
    # debug=True is convenient during development; disable it before exposing
    # the app on 0.0.0.0 in production.
    app.run(host="0.0.0.0", debug=True, port=7860)