UnityPaul committed on
Commit
f872bac
1 Parent(s): bc77749

Upload 3 files

Browse files
Files changed (3) hide show
  1. MiniLMv6.cs +61 -36
  2. MiniLMv6.sentis +2 -2
  3. info.json +2 -2
MiniLMv6.cs CHANGED
@@ -3,6 +3,7 @@ using UnityEngine;
3
  using Unity.Sentis;
4
  using System.IO;
5
  using System.Text;
 
6
 
7
  /*
8
  * MiniLM v6 Sentence Similarity Inference Code
@@ -35,40 +36,43 @@ public class MiniLM : MonoBehaviour
35
  const int START_TOKEN = 101;
36
  const int END_TOKEN = 102;
37
 
38
- Ops ops;
39
- ITensorAllocator allocator;
40
-
41
  //Store the vocabulary
42
  string[] tokens;
43
 
44
- IWorker engine;
 
 
45
 
46
  void Start()
47
  {
48
- allocator = new TensorCachingAllocator();
49
- ops = WorkerFactory.CreateOps(backend, allocator);
50
-
51
  tokens = File.ReadAllLines(Application.streamingAssetsPath + "/vocab.txt");
52
 
53
- Model model = ModelLoader.Load(Application.streamingAssetsPath + "/MiniLMv6.sentis");
54
 
55
- engine = WorkerFactory.CreateWorker(backend, model);
56
 
57
  var tokens1 = GetTokens(string1);
58
  var tokens2 = GetTokens(string2);
59
 
60
- TensorFloat embedding1 = GetEmbedding(tokens1);
61
- TensorFloat embedding2 = GetEmbedding(tokens2);
 
 
62
 
63
- Debug.Log("Similarity Score: " + DotScore(embedding1, embedding2));
64
  }
65
 
66
- float DotScore(TensorFloat embedding1, TensorFloat embedding2)
67
  {
68
- using var prod = ops.Mul(embedding1, embedding2);
69
- using var dot = ops.ReduceSum(prod, new int[] { 1 }, false);
70
- dot.MakeReadable();
71
- return dot[0];
 
 
 
 
 
72
  }
73
 
74
  TensorFloat GetEmbedding(List<int> tokens)
@@ -85,31 +89,52 @@ public class MiniLM : MonoBehaviour
85
 
86
  var inputs = new Dictionary<string, Tensor>
87
  {
88
- {"input_ids",input_ids },
89
- {"token_type_ids", token_type_ids},
90
- {"attention_mask", attention_mask }
91
  };
92
 
93
  engine.Execute(inputs);
94
 
95
- var tokenEmbeddings = engine.PeekOutput("output") as TensorFloat;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
 
97
- return MeanPooling(tokenEmbeddings, attention_mask);
98
  }
99
 
100
  //Get average of token embeddings taking into account the attention mask
101
- TensorFloat MeanPooling(TensorFloat tokenEmbeddings, TensorInt attentonMask)
102
  {
103
- using var mask0 = attentonMask.ShallowReshape(attentonMask.shape.Unsqueeze(-1)) as TensorInt;
104
- using var maskExpanded = ops.Expand(mask0, tokenEmbeddings.shape);
105
- using var maskExpandedF = ops.Cast(maskExpanded, DataType.Float) as TensorFloat;
106
- using var D = ops.Mul(tokenEmbeddings, maskExpandedF);
107
- using var A = ops.ReduceSum(D, new[] { 1 }, false);
108
- using var C = ops.ReduceSum(maskExpandedF, new[] { 1 }, false);
109
- using var B = ops.Clip(C, 1e-9f, float.MaxValue);
110
- using var E = ops.Div(A, B);
111
- using var F = ops.ReduceL2(E, new[] { 1 }, true);
112
- return ops.Div(E, F);
 
 
 
 
 
 
113
  }
114
 
115
  List<int> GetTokens(string text)
@@ -150,9 +175,9 @@ public class MiniLM : MonoBehaviour
150
  }
151
 
152
  private void OnDestroy()
153
- {
 
154
  engine?.Dispose();
155
- ops?.Dispose();
156
- allocator?.Dispose();
157
  }
 
158
  }
 
3
  using Unity.Sentis;
4
  using System.IO;
5
  using System.Text;
6
+ using FF = Unity.Sentis.Functional;
7
 
8
  /*
9
  * MiniLM v6 Sentence Similarity Inference Code
 
36
  const int START_TOKEN = 101;
37
  const int END_TOKEN = 102;
38
 
 
 
 
39
  //Store the vocabulary
40
  string[] tokens;
41
 
42
+ const int FEATURES = 384; //size of feature space
43
+
44
+ IWorker engine, dotScore;
45
 
46
  void Start()
47
  {
 
 
 
48
  tokens = File.ReadAllLines(Application.streamingAssetsPath + "/vocab.txt");
49
 
50
+ engine = CreateMLModel();
51
 
52
+ dotScore = CreateDotScoreModel();
53
 
54
  var tokens1 = GetTokens(string1);
55
  var tokens2 = GetTokens(string2);
56
 
57
+ using TensorFloat embedding1 = GetEmbedding(tokens1);
58
+ using TensorFloat embedding2 = GetEmbedding(tokens2);
59
+
60
+ float score = GetDotScore(embedding1, embedding2);
61
 
62
+ Debug.Log("Similarity Score: " + score);
63
  }
64
 
65
// Runs the dot-score worker on two embeddings and returns the first element of its output.
//   A, B: embeddings bound to the worker inputs "input_0" and "input_1" respectively.
// Throws System.InvalidOperationException when the worker's output is not a TensorFloat
// (previously this surfaced as an unexplained NullReferenceException).
float GetDotScore(TensorFloat A, TensorFloat B)
{
    var inputs = new Dictionary<string, Tensor>
    {
        ["input_0"] = A,
        ["input_1"] = B
    };
    dotScore.Execute(inputs);

    // The peeked tensor is not disposed here, same as before.
    // NOTE(review): presumably the worker keeps ownership of peeked outputs - confirm.
    if (!(dotScore.PeekOutput() is TensorFloat output))
    {
        throw new System.InvalidOperationException(
            "Dot score worker did not produce a TensorFloat output.");
    }

    // Finish pending work and bring the data to the CPU before indexing into it.
    output.CompleteOperationsAndDownload();
    return output[0];
}
77
 
78
  TensorFloat GetEmbedding(List<int> tokens)
 
89
 
90
  var inputs = new Dictionary<string, Tensor>
91
  {
92
+ {"input_0", input_ids },
93
+ {"input_1", attention_mask },
94
+ {"input_2", token_type_ids}
95
  };
96
 
97
  engine.Execute(inputs);
98
 
99
+ var output = engine.TakeOutputOwnership("output_0") as TensorFloat;
100
+ return output;
101
+ }
102
+
103
// Loads MiniLMv6.sentis from StreamingAssets, wraps it in a compiled model that also
// applies MeanPooling to the first output, and returns a worker for it on `backend`.
IWorker CreateMLModel()
{
    string modelPath = Application.streamingAssetsPath + "/MiniLMv6.sentis";
    Model baseModel = ModelLoader.Load(modelPath);

    // The compiled graph takes the same three inputs as the loaded model.
    // NOTE(review): assumes the model's inputs are ordered ids, attention mask, token types.
    Model pooledModel = Functional.Compile(
        (ids, mask, typeIds) =>
        {
            var outputs = baseModel.Forward(ids, mask, typeIds);
            // Only the first output (token embeddings) is pooled; the mask is reused for pooling.
            return MeanPooling(outputs[0], mask);
        },
        (baseModel.inputs[0], baseModel.inputs[1], baseModel.inputs[2])
    );

    return WorkerFactory.CreateWorker(backend, pooledModel);
}
118
 
119
  //Get average of token embeddings taking into account the attention mask
120
// Mask-weighted mean of the token embeddings along axis 1, followed by L2 normalisation.
//   tokenEmbeddings: per-token embeddings - assumed (1, N, FEATURES), TODO confirm.
//   attentionMask:   per-token mask - assumed (1, N), TODO confirm.
// Returns the normalised pooled embedding as a functional graph node.
FunctionalTensor MeanPooling(FunctionalTensor tokenEmbeddings, FunctionalTensor attentionMask)
{
    // Add a trailing axis to the mask and broadcast it across the feature dimension.
    FunctionalTensor expandedMask = attentionMask.Unsqueeze(-1).BroadcastTo(new[] { FEATURES });

    // Sum of the masked embeddings over axis 1 (axis dropped).
    FunctionalTensor maskedSum = FF.ReduceSum(tokenEmbeddings * expandedMask, 1, false);

    // Sum of the mask over axis 1; the small constant keeps the divisor non-zero.
    FunctionalTensor maskTotal = FF.ReduceSum(expandedMask, 1, false);
    FunctionalTensor mean = maskedSum / (maskTotal + 1e-9f);

    // Square root of the sum of squares over axis 1.
    FunctionalTensor norm = FF.Sqrt(FF.ReduceSum(FF.Square(mean), 1, true));

    return mean / norm;
}
128
+
129
// Compiles a two-input model that multiplies its inputs element-wise and sums the
// product along axis 1, then returns a worker for it on `backend`.
IWorker CreateDotScoreModel()
{
    // Both inputs are float tensors of shape (1, FEATURES).
    var embeddingShape = new TensorShape(1, FEATURES);

    Model compiled = Functional.Compile(
        (lhs, rhs) =>
        {
            var product = lhs * rhs;
            return Functional.ReduceSum(product, 1);
        },
        (InputDef.Float(embeddingShape), InputDef.Float(embeddingShape))
    );

    return WorkerFactory.CreateWorker(backend, compiled);
}
139
 
140
  List<int> GetTokens(string text)
 
175
  }
176
 
177
  private void OnDestroy()
178
+ {
179
+ dotScore?.Dispose();
180
  engine?.Dispose();
 
 
181
  }
182
+
183
  }
MiniLMv6.sentis CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:cd3cc73a83d426dd085c1839e587b6a7155ce91d6698f7ae2596a3f3cd02d1cf
3
- size 90952597
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c9a2597ce9edce4c09b32e993b7f906cce91fceb2f461a597b974f71ee70453d
3
+ size 90898400
info.json CHANGED
@@ -8,7 +8,7 @@
8
  "data": [
9
  "vocab.txt"
10
  ],
11
- "version":[
12
- "1.3.0-pre.3"
13
  ]
14
  }
 
8
  "data": [
9
  "vocab.txt"
10
  ],
11
+ "version": [
12
+ "1.4.0"
13
  ]
14
  }