dejanseo commited on
Commit
bf51e28
1 Parent(s): e44666c

Upload demo.py

Browse files
Files changed (1) hide show
  1. demo.py +50 -0
demo.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tensorflow as tf
2
+ import sentencepiece as spm
3
+ import numpy as np
4
+
5
+ # Paths to the model, tokenizer, and metadata
6
+ tflite_model_path = r"C:\Users\dejan\AppData\Local\Google\Chrome SxS\User Data\optimization_guide_model_store\43\E6DC4029A1E4B4C1\EF94C116CBE73994\model.tflite"
7
+ spm_model_path = r"C:\Users\dejan\AppData\Local\Google\Chrome SxS\User Data\optimization_guide_model_store\43\E6DC4029A1E4B4C1\EF94C116CBE73994\sentencepiece.model"
8
+
9
+ # Load the SentencePiece tokenizer model
10
+ sp = spm.SentencePieceProcessor()
11
+ sp.load(spm_model_path)
12
+
13
+ # Function to preprocess text input
14
+ def preprocess_text(text, sp):
15
+ # Tokenize and convert to ids
16
+ input_ids = sp.encode(text, out_type=int)
17
+ # Ensure input is the correct shape for the model
18
+ return np.array(input_ids, dtype=np.int32).reshape(1, -1)
19
+
20
+ # Load the TFLite model
21
+ interpreter = tf.lite.Interpreter(model_path=tflite_model_path)
22
+ interpreter.allocate_tensors()
23
+
24
+ # Get input and output details for model inference
25
+ input_details = interpreter.get_input_details()
26
+ output_details = interpreter.get_output_details()
27
+
28
+ # Function to generate embeddings for a given text
29
+ def generate_embeddings(text):
30
+ # Preprocess text input
31
+ input_data = preprocess_text(text, sp)
32
+
33
+ # Adjust input tensor size if necessary
34
+ interpreter.resize_tensor_input(input_details[0]['index'], input_data.shape)
35
+ interpreter.allocate_tensors()
36
+
37
+ # Set the input tensor with preprocessed data
38
+ interpreter.set_tensor(input_details[0]['index'], input_data)
39
+
40
+ # Run inference
41
+ interpreter.invoke()
42
+
43
+ # Extract the embedding output
44
+ embedding = interpreter.get_tensor(output_details[0]['index'])
45
+ return embedding
46
+
47
+ # Example usage
48
+ text = "Sample passage for embedding generation"
49
+ embedding = generate_embeddings(text)
50
+ print("Generated Embedding:", embedding)