sentis-MiniLM-v6 / MiniLMv6.cs
using System.Collections.Generic;
using UnityEngine;
using Unity.Sentis;
using System.IO;
using System.Text;
using FF = Unity.Sentis.Functional;
/*
 * MiniLM v6 Sentence Similarity Inference Code
 * ============================================
 *
 * Put this script on the Main Camera.
 *
 * In Assets/StreamingAssets put:
 *
 *   MiniLMv6.sentis
 *   vocab.txt
 *
 * Install the package com.unity.sentis
 *
 */
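//Pipeline: tokenize both strings, run MiniLM with mean pooling to get one L2-normalised
//384-dimensional embedding per sentence, then take the dot product of the two embeddings
//as the similarity score.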
public class MiniLM : MonoBehaviour
{
    const BackendType backend = BackendType.GPUCompute;

    string string1 = "That is a happy person"; // similarity = 1

    //Choose a string to compare string1 to:
    string string2 = "That is a happy dog";               // similarity = 0.695
    //string string2 = "That is a very happy person";     // similarity = 0.943
    //string string2 = "Today is a sunny day";            // similarity = 0.257

    //Special tokens
    const int START_TOKEN = 101;
    const int END_TOKEN = 102;

    //Store the vocabulary
    string[] tokens;

    const int FEATURES = 384; //size of feature space

    IWorker engine, dotScore;
    void Start()
    {
        tokens = File.ReadAllLines(Application.streamingAssetsPath + "/vocab.txt");

        engine = CreateMLModel();
        dotScore = CreateDotScoreModel();

        var tokens1 = GetTokens(string1);
        var tokens2 = GetTokens(string2);

        using TensorFloat embedding1 = GetEmbedding(tokens1);
        using TensorFloat embedding2 = GetEmbedding(tokens2);

        float score = GetDotScore(embedding1, embedding2);
        Debug.Log("Similarity Score: " + score);
    }
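
    //The embeddings returned by GetEmbedding are L2-normalised (see MeanPooling),
    //so their dot product equals the cosine similarity of the two sentences.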
    float GetDotScore(TensorFloat A, TensorFloat B)
    {
        var inputs = new Dictionary<string, Tensor>()
        {
            { "input_0", A },
            { "input_1", B }
        };
        dotScore.Execute(inputs);
        var output = dotScore.PeekOutput() as TensorFloat;
        output.CompleteOperationsAndDownload();
        return output[0];
    }
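
    //Runs the transformer on one tokenized sentence. The inputs follow the BERT convention:
    //token ids, an attention mask (all ones here, since there is no padding) and
    //token type ids (all zeros for a single sentence).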
    TensorFloat GetEmbedding(List<int> tokens)
    {
        int N = tokens.Count;

        using var input_ids = new TensorInt(new TensorShape(1, N), tokens.ToArray());
        using var token_type_ids = new TensorInt(new TensorShape(1, N), new int[N]);

        int[] mask = new int[N];
        for (int i = 0; i < mask.Length; i++)
        {
            mask[i] = 1;
        }
        using var attention_mask = new TensorInt(new TensorShape(1, N), mask);

        var inputs = new Dictionary<string, Tensor>
        {
            { "input_0", input_ids },
            { "input_1", attention_mask },
            { "input_2", token_type_ids }
        };
        engine.Execute(inputs);

        var output = engine.TakeOutputOwnership("output_0") as TensorFloat;
        return output;
    }
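
    //Loads the exported MiniLM model and uses the Functional API to append the
    //mean-pooling step, so the worker directly outputs one sentence embedding.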
    IWorker CreateMLModel()
    {
        Model model = ModelLoader.Load(Application.streamingAssetsPath + "/MiniLMv6.sentis");

        Model modelWithMeanPooling = Functional.Compile(
            (input_ids, attention_mask, token_type_ids) =>
            {
                var tokenEmbeddings = model.Forward(input_ids, attention_mask, token_type_ids)[0];
                return MeanPooling(tokenEmbeddings, attention_mask);
            },
            (model.inputs[0], model.inputs[1], model.inputs[2])
        );

        return WorkerFactory.CreateWorker(backend, modelWithMeanPooling);
    }
    //Get the average of the token embeddings, taking the attention mask into account,
    //then L2-normalise the result
    FunctionalTensor MeanPooling(FunctionalTensor tokenEmbeddings, FunctionalTensor attentionMask)
    {
        var mask = attentionMask.Unsqueeze(-1).BroadcastTo(new[] { FEATURES }); //shape=(1,N,FEATURES)
        var A = FF.ReduceSum(tokenEmbeddings * mask, 1, false);                 //shape=(1,FEATURES)
        var B = A / (FF.ReduceSum(mask, 1, false) + 1e-9f);                     //shape=(1,FEATURES)
        var C = FF.Sqrt(FF.ReduceSum(FF.Square(B), 1, true));                   //L2 norm, shape=(1,1)
        return B / C;                                                           //shape=(1,FEATURES)
    }
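
    //Builds a small model that computes the dot product of two FEATURES-dimensional vectors.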
    IWorker CreateDotScoreModel()
    {
        Model dotScoreModel = Functional.Compile(
            (input1, input2) => Functional.ReduceSum(input1 * input2, 1),
            (InputDef.Float(new TensorShape(1, FEATURES)),
             InputDef.Float(new TensorShape(1, FEATURES)))
        );
        return WorkerFactory.CreateWorker(backend, dotScoreModel);
    }
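
    //Greedy WordPiece-style tokenizer: for each lower-cased word, repeatedly match the
    //longest entry found in vocab.txt, using the "##" prefix for word continuations,
    //and wrap the whole sequence in the START and END tokens.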
    List<int> GetTokens(string text)
    {
        //split over whitespace
        string[] words = text.ToLower().Split(null);

        var ids = new List<int>
        {
            START_TOKEN
        };

        string s = "";

        foreach (var word in words)
        {
            int start = 0;
            for (int i = word.Length; i >= 0; i--)
            {
                string subword = start == 0 ? word.Substring(start, i) : "##" + word.Substring(start, i - start);
                int index = System.Array.IndexOf(tokens, subword);
                if (index >= 0)
                {
                    ids.Add(index);
                    s += subword + " ";
                    if (i == word.Length) break;
                    start = i;
                    i = word.Length + 1;
                }
            }
        }
        ids.Add(END_TOKEN);

        Debug.Log("Tokenized sentence = " + s);
        return ids;
    }
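
    //Dispose both workers so their GPU resources are released when this component is destroyed.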
    private void OnDestroy()
    {
        dotScore?.Dispose();
        engine?.Dispose();
    }
}