StevenChen16 committed on
Commit d2250f6
1 Parent(s): 16081bf

Update app.py

Files changed (1)
  1. app.py +306 -311
app.py CHANGED
@@ -4,6 +4,9 @@ import re
4
  import uuid
5
  import tempfile
6
  import json
7
  from argparse import ArgumentParser
8
  from threading import Thread
9
  from queue import Queue
@@ -35,10 +38,89 @@ from langchain_community.vectorstores.faiss import FAISS
35
  from langchain_huggingface import HuggingFaceEmbeddings
36
  from tqdm import tqdm
37
  import joblib
38
-
39
  import spaces
40
 
41
- # Token streamer for generation
42
  class TokenStreamer(BaseStreamer):
43
  def __init__(self, skip_prompt: bool = False, timeout=None):
44
  self.skip_prompt = skip_prompt
@@ -73,19 +155,54 @@ class TokenStreamer(BaseStreamer):
73
  else:
74
  return value
75
 
76
- # File loader mapping
77
- LOADER_MAPPING = {
78
- '.pdf': PyPDFLoader,
79
- '.txt': TextLoader,
80
- '.md': UnstructuredMarkdownLoader,
81
- '.csv': CSVLoader,
82
- '.jpg': UnstructuredImageLoader,
83
- '.jpeg': UnstructuredImageLoader,
84
- '.png': UnstructuredImageLoader,
85
- '.json': JSONLoader,
86
- '.html': BSHTMLLoader,
87
- '.htm': BSHTMLLoader
88
- }
89
 
90
  def load_single_file(file_path):
91
  _, ext = os.path.splitext(file_path)
@@ -112,13 +229,13 @@ def load_files(file_paths: list):
112
  docs.extend(loaded_docs)
113
  return docs
114
 
115
- # def split_text(txt, chunk_size=200, overlap=20):
116
- # if not txt:
117
- # return None
118
 
119
- # splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=overlap)
120
- # docs = splitter.split_documents(txt)
121
- # return docs
122
 
123
  def create_embedding_model(model_file):
124
  embedding = HuggingFaceEmbeddings(model_name=model_file, model_kwargs={'trust_remote_code': True})
@@ -127,70 +244,14 @@ def create_embedding_model(model_file):
127
  def save_file_paths(store_path, file_paths):
128
  joblib.dump(file_paths, f'{store_path}/file_paths.pkl')
129
 
130
- def load_file_paths(store_path):
131
- file_paths_file = f'{store_path}/file_paths.pkl'
132
- if os.path.exists(file_paths_file):
133
- return joblib.load(file_paths_file)
134
- return None
135
-
136
- def file_paths_match(store_path, file_paths):
137
- saved_file_paths = load_file_paths(store_path)
138
- return saved_file_paths == file_paths
139
-
140
- # def create_vector_store(docs, store_file, embeddings):
141
- # vector_store = FAISS.from_documents(docs, embeddings)
142
- # vector_store.save_local(store_file)
143
- # return vector_store
144
-
145
- def load_vector_store(store_path, embeddings):
146
- if os.path.exists(store_path):
147
- vector_store = FAISS.load_local(store_path, embeddings, allow_dangerous_deserialization=True)
148
- return vector_store
149
- else:
150
- return None
151
-
152
- def split_text(txt, chunk_size=200, overlap=20):
153
- if not txt:
154
- return [] # return an empty list instead of None
155
-
156
- splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=overlap)
157
- docs = splitter.split_documents(txt)
158
- return docs
159
-
160
  def create_vector_store(docs, store_file, embeddings):
161
- if not docs: # added validation
162
  raise ValueError("No documents provided for creating vector store")
163
 
164
  vector_store = FAISS.from_documents(docs, embeddings)
165
  vector_store.save_local(store_file)
166
  return vector_store
167
 
168
- def load_or_create_store(store_path, file_paths, embeddings):
169
- try:
170
- if os.path.exists(store_path) and file_paths_match(store_path, file_paths):
171
- print("Vector database is consistent with last use, no need to rewrite")
172
- vector_store = load_vector_store(store_path, embeddings)
173
- if vector_store:
174
- return vector_store
175
-
176
- print("Rewriting database")
177
- pages = load_files(file_paths)
178
- if not pages: # added validation
179
- raise ValueError("No documents loaded from provided file paths")
180
-
181
- docs = split_text(pages)
182
- if not docs: # added validation
183
- raise ValueError("No documents created after splitting text")
184
-
185
- vector_store = create_vector_store(docs, store_path, embeddings)
186
- save_file_paths(store_path, file_paths)
187
- return vector_store
188
-
189
- except Exception as e:
190
- print(f"Error creating vector store: {str(e)}")
191
- # decide whether to re-raise the exception as needed
192
- raise
193
-
194
  def query_vector_store(vector_store: FAISS, query, k=4, relevance_threshold=0.8):
195
  retriever = vector_store.as_retriever(
196
  search_type="similarity_score_threshold",
@@ -200,89 +261,169 @@ def query_vector_store(vector_store: FAISS, query, k=4, relevance_threshold=0.8)
200
  context = [doc.page_content for doc in similar_docs]
201
  return context
202
 
203
- class ModelWorker:
204
- def __init__(self, model_path, device='cuda'):
205
- self.device = device
206
- self.glm_model = AutoModel.from_pretrained(
207
- model_path,
208
- trust_remote_code=True,
209
- device=device
210
- ).to(device).eval()
211
- self.glm_tokenizer = AutoTokenizer.from_pretrained(
212
- model_path,
213
- trust_remote_code=True
214
- )
215
-
216
- @torch.inference_mode()
217
- def generate_stream(self, params):
218
- prompt = params["prompt"]
219
- temperature = float(params.get("temperature", 1.0))
220
- top_p = float(params.get("top_p", 1.0))
221
- max_new_tokens = int(params.get("max_new_tokens", 256))
222
-
223
- inputs = self.glm_tokenizer([prompt], return_tensors="pt")
224
- inputs = inputs.to(self.device)
225
- streamer = TokenStreamer(skip_prompt=True)
226
-
227
- thread = Thread(
228
- target=self.glm_model.generate,
229
- kwargs=dict(
230
- **inputs,
231
- max_new_tokens=int(max_new_tokens),
232
- temperature=float(temperature),
233
- top_p=float(top_p),
234
- streamer=streamer
235
- )
236
- )
237
- thread.start()
238
-
239
- for token_id in streamer:
240
- yield token_id
241
-
242
- @spaces.GPU
243
- def generate_stream_gate(self, params):
244
- try:
245
- for x in self.generate_stream(params):
246
- yield x
247
- except Exception as e:
248
- print("Caught Unknown Error", e)
249
- ret = "Server Error"
250
- yield ret
251
-
252
- def initialize_embedding_model_and_vector_store(Embedding_Model, store_path, file_paths):
253
  embedding_model = create_embedding_model(Embedding_Model)
254
- vector_store = load_or_create_store(store_path, file_paths, embedding_model)
255
- return vector_store, embedding_model
256
 
257
- def handle_file_upload(files):
258
- if not files:
259
- return None
260
- file_paths = [file.name for file in files]
261
- return file_paths
262
 
263
- def reinitialize_database(files, progress=gr.Progress()):
264
- global vector_store, embedding_model
265
-
266
  if not files:
267
  return "No files uploaded. Please upload files first."
268
-
269
- file_paths = [file.name for file in files]
270
 
271
- progress(0, desc="Initializing embedding model...")
272
- embedding_model = create_embedding_model(Embedding_Model)
273
 
274
- progress(0.3, desc="Loading documents...")
275
- pages = load_files(file_paths)
276
 
277
- progress(0.5, desc="Splitting text...")
278
- docs = split_text(pages)
279
 
280
- progress(0.7, desc="Creating vector store...")
281
- vector_store = create_vector_store(docs, store_path, embedding_model)
282
- save_file_paths(store_path, file_paths)
283
 
284
- return "Database reinitialized successfully!"
285
 
286
 
287
  if __name__ == "__main__":
288
  parser = ArgumentParser()
@@ -291,7 +432,6 @@ if __name__ == "__main__":
291
  parser.add_argument("--flow-path", type=str, default="./glm-4-voice-decoder")
292
  parser.add_argument("--model-path", type=str, default="THUDM/glm-4-voice-9b")
293
  parser.add_argument("--tokenizer-path", type=str, default="THUDM/glm-4-voice-tokenizer")
294
- # parser.add_argument("--whisper_model", type=str, default="base")
295
  parser.add_argument("--share", action='store_true')
296
  args = parser.parse_args()
297
 
@@ -307,169 +447,19 @@ if __name__ == "__main__":
307
  feature_extractor = None
308
  glm_model = None
309
  glm_tokenizer = None
310
- vector_store = None
311
- embedding_model = None
312
  whisper_transcribe_model = None
313
  model_worker = None
314
 
315
- # RAG configuration
316
  Embedding_Model = 'intfloat/multilingual-e5-large-instruct'
317
- file_paths = []
318
- store_path = './data.faiss'
319
 
320
- def initialize_fn():
321
- global audio_decoder, feature_extractor, whisper_model, glm_model, glm_tokenizer
322
- global vector_store, embedding_model, whisper_transcribe_model, model_worker
323
-
324
- if audio_decoder is not None:
325
- return
326
-
327
- model_worker = ModelWorker(args.model_path, device)
328
- glm_tokenizer = model_worker.glm_tokenizer
329
-
330
- audio_decoder = AudioDecoder(
331
- config_path=flow_config,
332
- flow_ckpt_path=flow_checkpoint,
333
- hift_ckpt_path=hift_checkpoint,
334
- device=device
335
- )
336
-
337
- whisper_model = WhisperVQEncoder.from_pretrained(args.tokenizer_path).eval().to(device)
338
- feature_extractor = WhisperFeatureExtractor.from_pretrained(args.tokenizer_path)
339
-
340
- embedding_model = create_embedding_model(Embedding_Model)
341
- vector_store = load_or_create_store(store_path, file_paths, embedding_model)
342
-
343
- whisper_transcribe_model = whisper.load_model("base")
344
-
345
- def clear_fn():
346
- return [], [], '', '', '', None, None
347
-
348
- def inference_fn(
349
- temperature: float,
350
- top_p: float,
351
- max_new_token: int,
352
- input_mode,
353
- audio_path: str | None,
354
- input_text: str | None,
355
- history: list[dict],
356
- previous_input_tokens: str,
357
- previous_completion_tokens: str,
358
- ):
359
- global whisper_transcribe_model, vector_store
360
- using_context = False
361
-
362
- if input_mode == "audio":
363
- assert audio_path is not None
364
- history.append({"role": "user", "content": {"path": audio_path}})
365
- audio_tokens = extract_speech_token(
366
- whisper_model, feature_extractor, [audio_path]
367
- )[0]
368
- if len(audio_tokens) == 0:
369
- raise gr.Error("No audio tokens extracted")
370
- audio_tokens = "".join([f"<|audio_{x}|>" for x in audio_tokens])
371
- audio_tokens = "<|begin_of_audio|>" + audio_tokens + "<|end_of_audio|>"
372
- user_input = audio_tokens
373
- system_prompt = "User will provide you with a speech instruction. Do it step by step."
374
-
375
- whisper_result = whisper_transcribe_model.transcribe(audio_path)
376
- transcribed_text = whisper_result['text']
377
- context = query_vector_store(vector_store, transcribed_text, 4, 0.7)
378
- else:
379
- assert input_text is not None
380
- history.append({"role": "user", "content": input_text})
381
- user_input = input_text
382
- system_prompt = "User will provide you with a text instruction. Do it step by step."
383
- context = query_vector_store(vector_store, input_text, 4, 0.7)
384
-
385
- if context is not None:
386
- using_context = True
387
-
388
- inputs = previous_input_tokens + previous_completion_tokens
389
- inputs = inputs.strip()
390
- if "<|system|>" not in inputs:
391
- inputs += f"<|system|>\n{system_prompt}"
392
- if ("<|context|>" not in inputs) and (using_context == True):
393
- inputs += f"<|context|> According to the following content: {context}, Please answer the question"
394
- if "<|context|>" not in inputs and context is not None:
395
- inputs += f"<|context|>\n{context}"
396
- inputs += f"<|user|>\n{user_input}<|assistant|>streaming_transcription\n"
397
-
398
- with torch.no_grad():
399
- text_tokens, audio_tokens = [], []
400
- audio_offset = glm_tokenizer.convert_tokens_to_ids('<|audio_0|>')
401
- end_token_id = glm_tokenizer.convert_tokens_to_ids('<|user|>')
402
- complete_tokens = []
403
- prompt_speech_feat = torch.zeros(1, 0, 80).to(device)
404
- flow_prompt_speech_token = torch.zeros(1, 0, dtype=torch.int64).to(device)
405
- this_uuid = str(uuid.uuid4())
406
- tts_speechs = []
407
- tts_mels = []
408
- prev_mel = None
409
- is_finalize = False
410
- block_size = 10
411
-
412
- # Generate tokens using ModelWorker directly instead of API
413
- for token_id in model_worker.generate_stream_gate({
414
- "prompt": inputs,
415
- "temperature": temperature,
416
- "top_p": top_p,
417
- "max_new_tokens": max_new_token,
418
- }):
419
- if isinstance(token_id, str): # Error case
420
- yield history, inputs, '', token_id, None, None
421
- return
422
-
423
- if token_id == end_token_id:
424
- is_finalize = True
425
- if len(audio_tokens) >= block_size or (is_finalize and audio_tokens):
426
- block_size = 20
427
- tts_token = torch.tensor(audio_tokens, device=device).unsqueeze(0)
428
-
429
- if prev_mel is not None:
430
- prompt_speech_feat = torch.cat(tts_mels, dim=-1).transpose(1, 2)
431
-
432
- tts_speech, tts_mel = audio_decoder.token2wav(
433
- tts_token,
434
- uuid=this_uuid,
435
- prompt_token=flow_prompt_speech_token.to(device),
436
- prompt_feat=prompt_speech_feat.to(device),
437
- finalize=is_finalize
438
- )
439
- prev_mel = tts_mel
440
-
441
- tts_speechs.append(tts_speech.squeeze())
442
- tts_mels.append(tts_mel)
443
- yield history, inputs, '', '', (22050, tts_speech.squeeze().cpu().numpy()), None
444
- flow_prompt_speech_token = torch.cat((flow_prompt_speech_token, tts_token), dim=-1)
445
- audio_tokens = []
446
-
447
- if not is_finalize:
448
- complete_tokens.append(token_id)
449
- if token_id >= audio_offset:
450
- audio_tokens.append(token_id - audio_offset)
451
- else:
452
- text_tokens.append(token_id)
453
-
454
- # Generate final audio and save
455
- tts_speech = torch.cat(tts_speechs, dim=-1).cpu()
456
- complete_text = glm_tokenizer.decode(complete_tokens, spaces_between_special_tokens=False)
457
-
458
- with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
459
- torchaudio.save(f, tts_speech.unsqueeze(0), 22050, format="wav")
460
-
461
- history.append({"role": "assistant", "content": {"path": f.name, "type": "audio/wav"}})
462
- history.append({"role": "assistant", "content": glm_tokenizer.decode(text_tokens, ignore_special_tokens=False)})
463
- yield history, inputs, complete_text, '', None, (22050, tts_speech.numpy())
464
-
465
- def update_input_interface(input_mode):
466
- if input_mode == "audio":
467
- return [gr.update(visible=True), gr.update(visible=False)]
468
- else:
469
- return [gr.update(visible=False), gr.update(visible=True)]
470
-
471
- # Create Gradio interface with new layout
472
  with gr.Blocks(title="GLM-4-Voice Demo", fill_height=True) as demo:
473
  with gr.Row():
474
  # Left column for chat interface
475
  with gr.Column(scale=2):
@@ -534,7 +524,7 @@ if __name__ == "__main__":
534
  file_count="multiple"
535
  )
536
 
537
- reinit_btn = gr.Button("Reinitialize Database", variant="secondary")
538
  status_text = gr.Textbox(label="Status", interactive=False)
539
 
540
  history_state = gr.State([])
@@ -550,6 +540,7 @@ if __name__ == "__main__":
550
  audio,
551
  text_input,
552
  history_state,
553
  ],
554
  outputs=[
555
  history_state,
@@ -576,12 +567,16 @@ if __name__ == "__main__":
576
  outputs=[audio, text_input]
577
  )
578
 
579
- # Database reinitialization handler
580
  reinit_btn.click(
581
  reinitialize_database,
582
- inputs=[file_upload],
583
  outputs=[status_text]
584
  )
585
 
586
  # Initialize models and launch interface
587
  initialize_fn()
 
4
  import uuid
5
  import tempfile
6
  import json
7
+ import time
8
+ import shutil
9
+ from pathlib import Path
10
  from argparse import ArgumentParser
11
  from threading import Thread
12
  from queue import Queue
 
38
  from langchain_huggingface import HuggingFaceEmbeddings
39
  from tqdm import tqdm
40
  import joblib
 
41
  import spaces
42
 
43
+ # File loader mapping
44
+ LOADER_MAPPING = {
45
+ '.pdf': PyPDFLoader,
46
+ '.txt': TextLoader,
47
+ '.md': UnstructuredMarkdownLoader,
48
+ '.csv': CSVLoader,
49
+ '.jpg': UnstructuredImageLoader,
50
+ '.jpeg': UnstructuredImageLoader,
51
+ '.png': UnstructuredImageLoader,
52
+ '.json': JSONLoader,
53
+ '.html': BSHTMLLoader,
54
+ '.htm': BSHTMLLoader
55
+ }
56
+
57
+ class SessionManager:
58
+ def __init__(self, base_path="./sessions"):
59
+ self.base_path = Path(base_path)
60
+ self.base_path.mkdir(exist_ok=True)
61
+
62
+ def create_session(self):
63
+ session_id = str(uuid.uuid4())
64
+ session_path = self.base_path / session_id
65
+ session_path.mkdir(exist_ok=True)
66
+ return session_id
67
+
68
+ def get_session_path(self, session_id):
69
+ return self.base_path / session_id
70
+
71
+ def cleanup_old_sessions(self, max_age_hours=24):
72
+ current_time = time.time()
73
+ for session_dir in self.base_path.iterdir():
74
+ if session_dir.is_dir():
75
+ dir_stats = os.stat(session_dir)
76
+ age_hours = (current_time - dir_stats.st_mtime) / 3600
77
+ if age_hours > max_age_hours:
78
+ shutil.rmtree(session_dir)
79
+
80
+ class VectorStoreManager:
81
+ def __init__(self, session_manager, embedding_model):
82
+ self.session_manager = session_manager
83
+ self.embedding_model = embedding_model
84
+ self.stores = {}
85
+
86
+ def get_store_path(self, session_id):
87
+ session_path = self.session_manager.get_session_path(session_id)
88
+ return session_path / "vector_store.faiss"
89
+
90
+ def create_store(self, session_id, files):
91
+ if not files:
92
+ return None
93
+
94
+ store_path = self.get_store_path(session_id)
95
+ file_paths = [f.name for f in files]
96
+
97
+ pages = load_files(file_paths)
98
+ if not pages:
99
+ return None
100
+
101
+ docs = split_text(pages)
102
+ if not docs:
103
+ return None
104
+
105
+ vector_store = FAISS.from_documents(docs, self.embedding_model)
106
+ vector_store.save_local(str(store_path))
107
+ save_file_paths(str(store_path.parent), file_paths)
108
+
109
+ self.stores[session_id] = vector_store
110
+ return vector_store
111
+
112
+ def get_store(self, session_id):
113
+ if session_id in self.stores:
114
+ return self.stores[session_id]
115
+
116
+ store_path = self.get_store_path(session_id)
117
+ if store_path.exists():
118
+ vector_store = FAISS.load_local(str(store_path), self.embedding_model)
119
+ self.stores[session_id] = vector_store
120
+ return vector_store
121
+
122
+ return None
123
+
124
  class TokenStreamer(BaseStreamer):
125
  def __init__(self, skip_prompt: bool = False, timeout=None):
126
  self.skip_prompt = skip_prompt
 
155
  else:
156
  return value
157
 
158
+ class ModelWorker:
159
+ def __init__(self, model_path, device='cuda'):
160
+ self.device = device
161
+ self.glm_model = AutoModel.from_pretrained(
162
+ model_path,
163
+ trust_remote_code=True,
164
+ device=device
165
+ ).to(device).eval()
166
+ self.glm_tokenizer = AutoTokenizer.from_pretrained(
167
+ model_path,
168
+ trust_remote_code=True
169
+ )
170
+
171
+ @torch.inference_mode()
172
+ def generate_stream(self, params):
173
+ prompt = params["prompt"]
174
+ temperature = float(params.get("temperature", 1.0))
175
+ top_p = float(params.get("top_p", 1.0))
176
+ max_new_tokens = int(params.get("max_new_tokens", 256))
177
+
178
+ inputs = self.glm_tokenizer([prompt], return_tensors="pt")
179
+ inputs = inputs.to(self.device)
180
+ streamer = TokenStreamer(skip_prompt=True)
181
+
182
+ thread = Thread(
183
+ target=self.glm_model.generate,
184
+ kwargs=dict(
185
+ **inputs,
186
+ max_new_tokens=int(max_new_tokens),
187
+ temperature=float(temperature),
188
+ top_p=float(top_p),
189
+ streamer=streamer
190
+ )
191
+ )
192
+ thread.start()
193
+
194
+ for token_id in streamer:
195
+ yield token_id
196
+
197
+ @spaces.GPU
198
+ def generate_stream_gate(self, params):
199
+ try:
200
+ for x in self.generate_stream(params):
201
+ yield x
202
+ except Exception as e:
203
+ print("Caught Unknown Error", e)
204
+ ret = "Server Error"
205
+ yield ret
206
 
207
  def load_single_file(file_path):
208
  _, ext = os.path.splitext(file_path)
 
229
  docs.extend(loaded_docs)
230
  return docs
231
 
232
+ def split_text(txt, chunk_size=200, overlap=20):
233
+ if not txt:
234
+ return []
235
 
236
+ splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=overlap)
237
+ docs = splitter.split_documents(txt)
238
+ return docs
239
 
240
  def create_embedding_model(model_file):
241
  embedding = HuggingFaceEmbeddings(model_name=model_file, model_kwargs={'trust_remote_code': True})
 
244
  def save_file_paths(store_path, file_paths):
245
  joblib.dump(file_paths, f'{store_path}/file_paths.pkl')
246
 
247
  def create_vector_store(docs, store_file, embeddings):
248
+ if not docs:
249
  raise ValueError("No documents provided for creating vector store")
250
 
251
  vector_store = FAISS.from_documents(docs, embeddings)
252
  vector_store.save_local(store_file)
253
  return vector_store
254
 
255
  def query_vector_store(vector_store: FAISS, query, k=4, relevance_threshold=0.8):
256
  retriever = vector_store.as_retriever(
257
  search_type="similarity_score_threshold",
 
261
  context = [doc.page_content for doc in similar_docs]
262
  return context
263
 
264
+ def initialize_fn():
265
+ global audio_decoder, feature_extractor, whisper_model, glm_model, glm_tokenizer
266
+ global session_manager, vector_store_manager, whisper_transcribe_model, model_worker
267
+
268
+ if audio_decoder is not None:
269
+ return
270
+
271
+ model_worker = ModelWorker(args.model_path, device)
272
+ glm_tokenizer = model_worker.glm_tokenizer
273
+
274
+ audio_decoder = AudioDecoder(
275
+ config_path=flow_config,
276
+ flow_ckpt_path=flow_checkpoint,
277
+ hift_ckpt_path=hift_checkpoint,
278
+ device=device
279
+ )
280
+
281
+ whisper_model = WhisperVQEncoder.from_pretrained(args.tokenizer_path).eval().to(device)
282
+ feature_extractor = WhisperFeatureExtractor.from_pretrained(args.tokenizer_path)
283
+
284
  embedding_model = create_embedding_model(Embedding_Model)
285
+ session_manager = SessionManager()
286
+ vector_store_manager = VectorStoreManager(session_manager, embedding_model)
287
+
288
+ whisper_transcribe_model = whisper.load_model("base")
289
 
290
+ def clear_fn():
291
+ return [], [], '', '', '', None, None
292
 
293
+ def reinitialize_database(files, session_id, progress=gr.Progress()):
294
  if not files:
295
  return "No files uploaded. Please upload files first."
296
 
297
+ progress(0.5, desc="Processing documents and creating vector store...")
298
+ vector_store = vector_store_manager.create_store(session_id, files)
299
 
300
+ if vector_store is None:
301
+ return "Failed to create vector store. Please check your documents."
302
 
303
+ return "Database initialized successfully!"
304
+
305
+ def inference_fn(
306
+ temperature: float,
307
+ top_p: float,
308
+ max_new_token: int,
309
+ input_mode,
310
+ audio_path: str | None,
311
+ input_text: str | None,
312
+ history: list[dict],
313
+ session_id: str,
314
+ ):
315
+ vector_store = vector_store_manager.get_store(session_id)
316
+ using_context = False
317
+ context = None
318
 
319
+ if input_mode == "audio":
320
+ assert audio_path is not None
321
+ history.append({"role": "user", "content": {"path": audio_path}})
322
+ audio_tokens = extract_speech_token(
323
+ whisper_model, feature_extractor, [audio_path]
324
+ )[0]
325
+ if len(audio_tokens) == 0:
326
+ raise gr.Error("No audio tokens extracted")
327
+ audio_tokens = "".join([f"<|audio_{x}|>" for x in audio_tokens])
328
+ audio_tokens = "<|begin_of_audio|>" + audio_tokens + "<|end_of_audio|>"
329
+ user_input = audio_tokens
330
+ system_prompt = "User will provide you with a speech instruction. Do it step by step."
331
+
332
+ if vector_store:
333
+ whisper_result = whisper_transcribe_model.transcribe(audio_path)
334
+ transcribed_text = whisper_result['text']
335
+ context = query_vector_store(vector_store, transcribed_text, 4, 0.7)
336
+ else:
337
+ assert input_text is not None
338
+ history.append({"role": "user", "content": input_text})
339
+ user_input = input_text
340
+ system_prompt = "User will provide you with a text instruction. Do it step by step."
341
+ if vector_store:
342
+ context = query_vector_store(vector_store, input_text, 4, 0.7)
343
+
344
+ if context:
345
+ using_context = True
346
+
347
+ inputs = ""
348
+ if "<|system|>" not in inputs:
349
+ inputs += f"<|system|>\n{system_prompt}"
350
+ if ("<|context|>" not in inputs) and (using_context == True):
351
+ inputs += f"<|context|> According to the following content: {context}, Please answer the question"
352
+ if "<|context|>" not in inputs and context is not None:
353
+ inputs += f"<|context|>\n{context}"
354
+ inputs += f"<|user|>\n{user_input}<|assistant|>streaming_transcription\n"
355
+
356
+ with torch.no_grad():
357
+ text_tokens, audio_tokens = [], []
358
+ audio_offset = glm_tokenizer.convert_tokens_to_ids('<|audio_0|>')
359
+ end_token_id = glm_tokenizer.convert_tokens_to_ids('<|user|>')
360
+ complete_tokens = []
361
+ prompt_speech_feat = torch.zeros(1, 0, 80).to(device)
362
+ flow_prompt_speech_token = torch.zeros(1, 0, dtype=torch.int64).to(device)
363
+ this_uuid = str(uuid.uuid4())
364
+ tts_speechs = []
365
+ tts_mels = []
366
+ prev_mel = None
367
+ is_finalize = False
368
+ block_size = 10
369
+
370
+ for token_id in model_worker.generate_stream_gate({
371
+ "prompt": inputs,
372
+ "temperature": temperature,
373
+ "top_p": top_p,
374
+ "max_new_tokens": max_new_token,
375
+ }):
376
+ if isinstance(token_id, str):
377
+ yield history, inputs, '', token_id, None, None
378
+ return
379
+
380
+ if token_id == end_token_id:
381
+ is_finalize = True
382
+ if len(audio_tokens) >= block_size or (is_finalize and audio_tokens):
383
+ block_size = 20
384
+ tts_token = torch.tensor(audio_tokens, device=device).unsqueeze(0)
385
+
386
+ if prev_mel is not None:
387
+ prompt_speech_feat = torch.cat(tts_mels, dim=-1).transpose(1, 2)
388
+
389
+ tts_speech, tts_mel = audio_decoder.token2wav(
390
+ tts_token,
391
+ uuid=this_uuid,
392
+ prompt_token=flow_prompt_speech_token.to(device),
393
+ prompt_feat=prompt_speech_feat.to(device),
394
+ finalize=is_finalize
395
+ )
396
+ prev_mel = tts_mel
397
+
398
+ tts_speechs.append(tts_speech.squeeze())
399
+ tts_mels.append(tts_mel)
400
+ yield history, inputs, '', '', (22050, tts_speech.squeeze().cpu().numpy()), None
401
+ flow_prompt_speech_token = torch.cat((flow_prompt_speech_token, tts_token), dim=-1)
402
+ audio_tokens = []
403
+
404
+ if not is_finalize:
405
+ complete_tokens.append(token_id)
406
+ if token_id >= audio_offset:
407
+ audio_tokens.append(token_id - audio_offset)
408
+ else:
409
+ text_tokens.append(token_id)
410
+
411
+ # Generate final audio and save
412
+ tts_speech = torch.cat(tts_speechs, dim=-1).cpu()
413
+ complete_text = glm_tokenizer.decode(complete_tokens, spaces_between_special_tokens=False)
414
 
415
+ with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
416
+ torchaudio.save(f, tts_speech.unsqueeze(0), 22050, format="wav")
417
+
418
+ history.append({"role": "assistant", "content": {"path": f.name, "type": "audio/wav"}})
419
+ history.append({"role": "assistant", "content": glm_tokenizer.decode(text_tokens, ignore_special_tokens=False)})
420
+ yield history, inputs, complete_text, '', None, (22050, tts_speech.numpy())
421
 
422
+ def update_input_interface(input_mode):
423
+ if input_mode == "audio":
424
+ return [gr.update(visible=True), gr.update(visible=False)]
425
+ else:
426
+ return [gr.update(visible=False), gr.update(visible=True)]
427
 
428
  if __name__ == "__main__":
429
  parser = ArgumentParser()
 
432
  parser.add_argument("--flow-path", type=str, default="./glm-4-voice-decoder")
433
  parser.add_argument("--model-path", type=str, default="THUDM/glm-4-voice-9b")
434
  parser.add_argument("--tokenizer-path", type=str, default="THUDM/glm-4-voice-tokenizer")
 
435
  parser.add_argument("--share", action='store_true')
436
  args = parser.parse_args()
437
 
 
447
  feature_extractor = None
448
  glm_model = None
449
  glm_tokenizer = None
450
+ session_manager = None
451
+ vector_store_manager = None
452
  whisper_transcribe_model = None
453
  model_worker = None
454
 
455
+ # Configuration
456
  Embedding_Model = 'intfloat/multilingual-e5-large-instruct'
457
 
458
+ # Create Gradio interface
459
  with gr.Blocks(title="GLM-4-Voice Demo", fill_height=True) as demo:
460
+ # Add session state
461
+ session_id = gr.State(lambda: session_manager.create_session())
462
+
463
  with gr.Row():
464
  # Left column for chat interface
465
  with gr.Column(scale=2):
 
524
  file_count="multiple"
525
  )
526
 
527
+ reinit_btn = gr.Button("Initialize Database", variant="secondary")
528
  status_text = gr.Textbox(label="Status", interactive=False)
529
 
530
  history_state = gr.State([])
 
540
  audio,
541
  text_input,
542
  history_state,
543
+ session_id,
544
  ],
545
  outputs=[
546
  history_state,
 
567
  outputs=[audio, text_input]
568
  )
569
 
570
+ # Database initialization handler
571
  reinit_btn.click(
572
  reinitialize_database,
573
+ inputs=[file_upload, session_id],
574
  outputs=[status_text]
575
  )
576
+
577
+ # Periodic cleanup of old sessions (optional)
578
+ if session_manager:
579
+ session_manager.cleanup_old_sessions()
580
 
581
  # Initialize models and launch interface
582
  initialize_fn()
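
For reference, a minimal sketch of how the per-session RAG pieces introduced in this commit fit together (illustrative only, not part of app.py; it assumes the SessionManager, VectorStoreManager, create_embedding_model, and query_vector_store definitions above, and UploadedFile below is a hypothetical stand-in for Gradio's uploaded-file objects, which only need a .name path):

from collections import namedtuple

# Hypothetical stand-in for a Gradio upload object; create_store() only reads .name.
UploadedFile = namedtuple("UploadedFile", ["name"])

embedding_model = create_embedding_model(Embedding_Model)
session_manager = SessionManager(base_path="./sessions")
vector_store_manager = VectorStoreManager(session_manager, embedding_model)

# Each browser session gets its own id and its own FAISS store under ./sessions/<session_id>/.
session_id = session_manager.create_session()
files = [UploadedFile("docs/handbook.pdf"), UploadedFile("docs/notes.md")]  # example paths
vector_store = vector_store_manager.create_store(session_id, files)

# At inference time, only chunks above the relevance threshold are returned as context.
if vector_store is not None:
    context = query_vector_store(vector_store, "What does the handbook cover?", k=4, relevance_threshold=0.7)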