dejanseo commited on
Commit
3b59afb
1 Parent(s): 92e35f3

Delete demo.py

Browse files
Files changed (1) hide show
  1. demo.py +0 -50
demo.py DELETED
@@ -1,50 +0,0 @@
1
- import tensorflow as tf
2
- import sentencepiece as spm
3
- import numpy as np
4
-
5
- # Paths to the model, tokenizer, and metadata
6
- tflite_model_path = r"C:\Users\dejan\AppData\Local\Google\Chrome SxS\User Data\optimization_guide_model_store\43\E6DC4029A1E4B4C1\EF94C116CBE73994\model.tflite"
7
- spm_model_path = r"C:\Users\dejan\AppData\Local\Google\Chrome SxS\User Data\optimization_guide_model_store\43\E6DC4029A1E4B4C1\EF94C116CBE73994\sentencepiece.model"
8
-
9
- # Load the SentencePiece tokenizer model
10
- sp = spm.SentencePieceProcessor()
11
- sp.load(spm_model_path)
12
-
13
- # Function to preprocess text input
14
- def preprocess_text(text, sp):
15
- # Tokenize and convert to ids
16
- input_ids = sp.encode(text, out_type=int)
17
- # Ensure input is the correct shape for the model
18
- return np.array(input_ids, dtype=np.int32).reshape(1, -1)
19
-
20
- # Load the TFLite model
21
- interpreter = tf.lite.Interpreter(model_path=tflite_model_path)
22
- interpreter.allocate_tensors()
23
-
24
- # Get input and output details for model inference
25
- input_details = interpreter.get_input_details()
26
- output_details = interpreter.get_output_details()
27
-
28
- # Function to generate embeddings for a given text
29
- def generate_embeddings(text):
30
- # Preprocess text input
31
- input_data = preprocess_text(text, sp)
32
-
33
- # Adjust input tensor size if necessary
34
- interpreter.resize_tensor_input(input_details[0]['index'], input_data.shape)
35
- interpreter.allocate_tensors()
36
-
37
- # Set the input tensor with preprocessed data
38
- interpreter.set_tensor(input_details[0]['index'], input_data)
39
-
40
- # Run inference
41
- interpreter.invoke()
42
-
43
- # Extract the embedding output
44
- embedding = interpreter.get_tensor(output_details[0]['index'])
45
- return embedding
46
-
47
- # Example usage
48
- text = "Sample passage for embedding generation"
49
- embedding = generate_embeddings(text)
50
- print("Generated Embedding:", embedding)