File size: 6,354 Bytes
6aa994f
 
 
 
 
 
 
 
 
 
 
0768472
6aa994f
 
6c7ce7e
 
6aa994f
 
 
6c7ce7e
 
 
 
 
6aa994f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c476445
 
 
 
 
6aa994f
 
 
 
 
 
 
da50bf2
6aa994f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0768472
b6625a4
da50bf2
b6625a4
 
 
 
 
da50bf2
 
b6625a4
 
c476445
b6625a4
 
 
bd9071b
da50bf2
 
6aa994f
c476445
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6aa994f
 
 
da50bf2
 
 
6aa994f
7e9fae4
6c7ce7e
 
 
 
 
da50bf2
6c7ce7e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7e9fae4
e27236f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
from flask import Flask, render_template, request, jsonify
from qdrant_client import QdrantClient
from qdrant_client import models
import torch.nn.functional as F
import torch
from torch import Tensor
from transformers import AutoTokenizer, AutoModel
from qdrant_client.models import Batch, PointStruct
from pickle import load, dump
import numpy as np
import os, time, sys
from datetime import datetime as dt
from datetime import timedelta
from datetime import timezone
from faster_whisper import WhisperModel
import io

# Flask application; every route below is registered on it.
app = Flask(__name__)

# Faster Whisper setup.
# Beam width passed as beam_size to wmodel.transcribe() in whisper_transcribe.
beamsize = 2
# faster-whisper "small" checkpoint, loaded once at import time on the CPU with
# int8 compute type (independent of the torch `device` chosen below).
wmodel = WhisperModel("guillaumekln/faster-whisper-small", device="cpu", compute_type="int8")

# Qdrant connection settings, read from the environment at import time.
# os.environ.get returns None for unset variables.
# NOTE(review): nothing here fails fast if these are missing — confirm how
# QdrantClient behaves with url=None before relying on that.
qdrant_api_key = os.environ.get("qdrant_api_key")
qdrant_url = os.environ.get("qdrant_url")

# Shared Qdrant client used by all routes: REST (prefer_grpc=False) on port 443.
client = QdrantClient(url=qdrant_url, port=443, api_key=qdrant_api_key, prefer_grpc=False)

# Torch device for the e5 embedding model and its inputs: CUDA when available, else CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def average_pool(last_hidden_states: Tensor,
                 attention_mask: Tensor) -> Tensor:
    """Mean-pool token embeddings over the sequence axis, skipping padding.

    Positions where ``attention_mask`` is 0 are zeroed before summing. The sum
    is then divided by the number of unmasked positions in each row.

    Args:
        last_hidden_states: token embeddings indexed as (batch, seq, hidden).
        attention_mask: (batch, seq) mask, non-zero for real tokens.

    Returns:
        A (batch, hidden) tensor of per-row mean embeddings.
    """
    keep = attention_mask.unsqueeze(-1).bool()
    summed = last_hidden_states.masked_fill(~keep, 0.0).sum(dim=1)
    token_counts = attention_mask.sum(dim=1).unsqueeze(-1)
    return summed / token_counts

# e5-base-v2 tokenizer and encoder, loaded once at import time; the encoder is
# moved to `device` so e5embed can run it on GPU when one is available.
# NOTE(review): the e5 model card recommends "query: " / "passage: " input
# prefixes; e5embed adds none. Changing that would require re-embedding every
# stored point, so confirm intent before touching it.
tokenizer = AutoTokenizer.from_pretrained('intfloat/e5-base-v2')
model = AutoModel.from_pretrained('intfloat/e5-base-v2').to(device)

def e5embed(query):
    """Embed text with the e5 encoder and return the vector as a plain list.

    The text is tokenized (truncated to 512 tokens) and run through the
    module-level ``model`` on ``device``. The token states are mean-pooled with
    ``average_pool`` and the result is L2-normalised.

    Args:
        query: the text to embed. Callers in this file pass a single string.
            If a list were passed, the per-item vectors would be concatenated
            by ``flatten()``.

    Returns:
        list[float]: the normalised embedding, ready to send to Qdrant.
    """
    batch_dict = tokenizer(query, max_length=512, padding=True, truncation=True, return_tensors='pt')
    batch_dict = {k: v.to(device) for k, v in batch_dict.items()}
    # Inference only: disabling autograd avoids building and holding a
    # backward graph for every request.
    with torch.no_grad():
        outputs = model(**batch_dict)
        embeddings = average_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
        embeddings = F.normalize(embeddings, p=2, dim=1)
    return embeddings.cpu().numpy().flatten().tolist()

def get_id(collection, page_size=10000):
    """Return the next free integer point id for *collection* (largest id + 1).

    Scrolls through the whole collection page by page. A single scroll call
    only sees the first ``page_size`` points, which would hand out an id that
    is already in use once the collection grows past one page.

    Args:
        collection: Qdrant collection name.
        page_size: number of points requested per scroll call.

    Returns:
        int: one more than the largest id found, or 1 for an empty collection.

    NOTE(review): assumes integer point ids (string/UUID ids fail in int()).
    The read-then-upsert done by callers is not atomic, so two concurrent
    inserts can receive the same id.
    """
    max_id = 0
    offset = None
    while True:
        records, offset = client.scroll(
            collection_name=collection,
            limit=page_size,
            offset=offset,
            with_payload=False,  # only ids are needed
            with_vectors=False,
        )
        for record in records:
            max_id = max(max_id, int(record.id))
        if offset is None:  # no further pages
            break
    return max_id + 1

@app.route("/")
def index():
    """Serve the front-end page rendered from templates/index.html."""
    page = "index.html"
    return render_template(page)

# Result limits per collection (kept at the values the endpoint has always used).
_JKS_SEARCH_LIMIT = 200
_TILS_SEARCH_LIMIT = 100
# Date reported for points stored without a 'date' payload field.
_SEARCH_DEFAULT_DATE = '20200101'
# Query prefixes that route a search to the 'tils' collection, mapped to the
# named vector searched for that prefix.
_TIL_QUERY_PREFIXES = (('tilc:', 'context'), ('til:', 'title'))


def _parse_search_query(query):
    """Split a raw query into (collection_name, vector_name, text_to_embed).

    A leading 'tilc:' or 'til:' (after surrounding whitespace) selects the
    'tils' collection and its 'context' or 'title' vector. Only that leading
    marker is removed; later occurrences stay in the text. Any other query
    searches 'jks' (vector_name is None) and is embedded unchanged.
    """
    stripped = query.strip()
    for prefix, vector_name in _TIL_QUERY_PREFIXES:
        if stripped.startswith(prefix):
            return 'tils', vector_name, stripped[len(prefix):].strip()
    return 'jks', None, query


def _format_search_hit(hit, collection_name):
    """Convert one Qdrant hit into the JSON-serialisable dict sent to the client."""
    payload = hit.payload
    date = payload.get('date', _SEARCH_DEFAULT_DATE)
    if collection_name == 'jks':
        return {"text": payload['text'], "date": str(int(date)), "id": hit.id}
    # NOTE(review): '<br>' implies the client renders `text` as HTML, yet
    # title/context are user-supplied via /add_item — possible stored XSS;
    # confirm how the front end inserts this before escaping here.
    text = payload['title']
    context = payload.get('context')
    if context:  # skip missing, None and empty contexts
        text += '<br>Context: ' + context
    return {"text": text, "url": payload['url'], "date": date, "id": hit.id}


@app.route("/search", methods=["POST"])
def search():
    """Embed the form's ``query`` and return the nearest points as JSON.

    The collection is chosen from the query text itself (see
    ``_parse_search_query``); the form's "collection" field is not consulted.
    'jks' hits carry text/date/id; 'tils' hits carry text/url/date/id.
    """
    query = request.form["query"]
    print('QUERY: ', query)
    collection_name, vector_name, query = _parse_search_query(query)

    started = time.time()
    sq = e5embed(query)
    print('EMBEDDING TIME: ', time.time() - started)

    started = time.time()
    if collection_name == 'jks':
        results = client.search(collection_name=collection_name, query_vector=sq,
                                with_payload=True, limit=_JKS_SEARCH_LIMIT)
    else:
        results = client.search(collection_name=collection_name, query_vector=(vector_name, sq),
                                with_payload=True, limit=_TILS_SEARCH_LIMIT)
    print('SEARCH TIME: ', time.time() - started)

    return jsonify([_format_search_hit(hit, collection_name) for hit in results])

# Leading markers removed from TIL titles, checked in this order
# ('TIL that' must precede 'TIL ').
_TIL_TITLE_PREFIXES = ('TIL that', 'TIL:', 'TIL ')
_URL_SCHEMES = ('https://', 'http://')


def _strip_til_prefix(title):
    """Remove one leading TIL marker (and surrounding whitespace) from *title*.

    Only a marker at the very start is removed, so words such as "UNTIL" or a
    later "TIL that" inside the sentence are preserved.
    """
    cleaned = title.strip()
    for prefix in _TIL_TITLE_PREFIXES:
        if cleaned.startswith(prefix):
            cleaned = cleaned[len(prefix):]
            break
    return cleaned.strip()


def _strip_url_scheme(url):
    """Drop a leading http:// or https:// scheme; the rest of *url* is untouched."""
    for scheme in _URL_SCHEMES:
        if url.lower().startswith(scheme):
            return url[len(scheme):]
    return url


@app.route("/add_item", methods=["POST"])
def add_item():
    """Embed and store a new item, returning which collection received it.

    An empty ``url`` stores ``title`` as a joke in 'jks' (date YYYYMMDD).
    Otherwise a TIL goes into 'tils' with a cleaned title, a scheme-less URL
    and a YYYYMMDD_HHMM date; only its 'title' vector is written.
    """
    title = request.form["title"]
    url = request.form["url"].strip()
    # One timestamp per request so every date derived below agrees.
    now = time.localtime()
    day = time.strftime("%Y%m%d", now)
    if not url:
        collection_name = 'jks'
        cid = get_id(collection_name)
        print('cid', cid, day)
        resp = client.upsert(
            collection_name=collection_name,
            points=Batch(ids=[cid], payloads=[{'text': title, 'date': day}], vectors=[e5embed(title)]),
        )
    else:
        collection_name = 'tils'
        cid = get_id(collection_name)
        print('cid', cid, day, collection_name)
        til = {
            'title': _strip_til_prefix(title),
            'url': _strip_url_scheme(url),
            "date": time.strftime("%Y%m%d_%H%M", now),
        }
        # NOTE(review): /search's 'tilc:' mode queries a 'context' vector,
        # which is never written here — confirm how context is populated.
        resp = client.upsert(
            collection_name=collection_name,
            points=[PointStruct(id=cid, payload=til, vector={"title": e5embed(til['title'])})],
        )
    print('Upsert response:', resp)
    return jsonify({"success": True, "index": collection_name})
    

@app.route("/delete_joke", methods=["POST"])
def delete_joke():
    """Delete the point with the given integer ``id`` from the form's ``collection``.

    Returns ``{"deleted": True}`` after issuing the delete, or a 400 error
    when ``id`` is not an integer.
    """
    collection_name = request.form["collection"]
    try:
        point_id = int(request.form["id"])
    except ValueError:
        return jsonify({'error': 'Invalid id'}), 400
    # NOTE(review): the collection name comes straight from the client and is
    # not restricted to the collections this app manages ('jks', 'tils').
    print('Deleting no.', point_id, 'from collection', collection_name)
    client.delete(collection_name=collection_name, points_selector=models.PointIdsList(points=[point_id]))
    return jsonify({"deleted": True})

@app.route("/whisper_transcribe", methods=["POST"])
def whisper_transcribe():
    """Transcribe an uploaded audio file (form field ``audio``) with faster-whisper.

    Returns ``{'transcription': text}``. Returns a 400 error when no file is sent
    or its extension is not one of mp3/wav/ogg/m4a.
    """
    if 'audio' not in request.files:
        return jsonify({'error': 'No file provided'}), 400

    upload = request.files['audio']
    allowed_extensions = {'mp3', 'wav', 'ogg', 'm4a'}
    filename = (upload.filename or '').lower()
    # Require an actual ".ext" suffix; a bare name like "mp3" is not accepted.
    extension = filename.rsplit('.', 1)[-1] if '.' in filename else ''
    if extension not in allowed_extensions:
        return jsonify({'error': 'Invalid file format'}), 400

    print('Transcribing audio')
    audio_buffer = io.BytesIO(upload.read())

    # Start timing before transcribe(): it does work up front, and the
    # returned segments generator decodes lazily while being consumed.
    starttime = time.time()
    segments, _info = wmodel.transcribe(audio_buffer, beam_size=beamsize)
    text = ''.join(segment.text for segment in segments)
    print('Time to transcribe:', time.time() - starttime, 'seconds')

    return jsonify({'transcription': text})


if __name__ == "__main__":
    # The server binds to every interface, and Werkzeug's interactive debugger
    # can execute arbitrary code, so debug mode is opt-in via FLASK_DEBUG=1.
    debug_enabled = os.environ.get("FLASK_DEBUG", "").strip().lower() in {"1", "true", "yes", "on"}
    app.run(host="0.0.0.0", debug=debug_enabled, port=7860)