shaunkhoo committed on
Commit
97f9a3c
1 Parent(s): 421bb51

feat: add config file and inference script

Browse files
Files changed (2) hide show
  1. config.json +91 -0
  2. inference.py +93 -0
config.json ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "description": "Binary classifier on harmful text in Singapore context",
3
+ "embedding": {
4
+ "tokenizer": "BAAI/bge-large-en-v1.5",
5
+ "model": "BAAI/bge-large-en-v1.5",
6
+ "max_length": 512,
7
+ "batch_size": 32
8
+ },
9
+ "classifier": {
10
+ "binary": {
11
+ "calibrated": true,
12
+ "threshold": {
13
+ "high_recall": 0.2,
14
+ "balanced": 0.5,
15
+ "high_precision": 0.8
16
+ },
17
+ "model_type": "ridge_classifier",
18
+ "model_fp": "models/lionguard-binary.onnx"
19
+ },
20
+ "hateful": {
21
+ "calibrated": false,
22
+ "threshold": {
23
+ "high_recall": 0.516,
24
+ "balanced": 0.827,
25
+ "high_precision": 1.254
26
+ },
27
+ "model_type": "ridge_classifier",
28
+ "model_fp": "models/lionguard-hateful.onnx"
29
+ },
30
+ "harassment": {
31
+ "calibrated": false,
32
+ "threshold": {
33
+ "high_recall": 1.326,
34
+ "balanced": 1.326,
35
+ "high_precision": 1.955
36
+ },
37
+ "model_type": "ridge_classifier",
38
+ "model_fp": "models/lionguard-harassment.onnx"
39
+ },
40
+ "public_harm": {
41
+ "calibrated": false,
42
+ "threshold": {
43
+ "high_recall": 0.953,
44
+ "balanced": 0.953,
45
+ "high_precision": 0.953
46
+ },
47
+ "model_type": "ridge_classifier",
48
+ "model_fp": "models/lionguard-public_harm.onnx"
49
+ },
50
+ "self_harm": {
51
+ "calibrated": false,
52
+ "threshold": {
53
+ "high_recall": 0.915,
54
+ "balanced": 0.915,
55
+ "high_precision": 0.915
56
+ },
57
+ "model_type": "ridge_classifier",
58
+ "model_fp": "models/lionguard-self_harm.onnx"
59
+ },
60
+ "sexual": {
61
+ "calibrated": false,
62
+ "threshold": {
63
+ "high_recall": 0.388,
64
+ "balanced": 0.500,
65
+ "high_precision": 0.702
66
+ },
67
+ "model_type": "ridge_classifier",
68
+ "model_fp": "models/lionguard-sexual.onnx"
69
+ },
70
+ "toxic": {
71
+ "calibrated": false,
72
+ "threshold": {
73
+ "high_recall": -0.089,
74
+ "balanced": 0.136,
75
+ "high_precision": 0.327
76
+ },
77
+ "model_type": "ridge_classifier",
78
+ "model_fp": "models/lionguard-toxic.onnx"
79
+ },
80
+ "violent": {
81
+ "calibrated": false,
82
+ "threshold": {
83
+ "high_recall": 0.317,
84
+ "balanced": 0.981,
85
+ "high_precision": 0.981
86
+ },
87
+ "model_type": "ridge_classifier",
88
+ "model_fp": "models/lionguard-violent.onnx"
89
+ }
90
+ }
91
+ }
inference.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import pandas as pd
3
+ import torch
4
+ from transformers import AutoTokenizer, AutoModel
5
+ from huggingface_hub import hf_hub_download
6
+ import sys
7
+ import json
8
+ import onnxruntime as rt
9
+
10
# Download the model config from the Hugging Face Hub and parse it once at
# import time; the functions below read `repo_path` and `config` as globals.
repo_path = "govtech/lionguard-v1"
config_path = hf_hub_download(repo_id=repo_path, filename="config.json")
# JSON is UTF-8 by specification, so do not rely on the platform's default
# text encoding (which is not UTF-8 on every system).
with open(config_path, 'r', encoding='utf-8') as f:
    config = json.load(f)
15
+
16
def get_embeddings(device, data):
    """Embed texts with the encoder named in the module-level ``config``.

    Args:
        device: Torch device (or device string) the encoder and the
            tokenized inputs are moved to.
        data: Sliceable sequence of texts; it is cut into consecutive
            chunks of ``config['embedding']['batch_size']`` items.

    Returns:
        ``np.array`` built from one L2-normalised embedding row per text.
    """
    embedding_cfg = config['embedding']

    # Set up tokenizer and encoder, with the encoder in eval mode on `device`.
    tokenizer = AutoTokenizer.from_pretrained(embedding_cfg['tokenizer'])
    model = AutoModel.from_pretrained(embedding_cfg['model'])
    model.eval()
    model.to(device)

    batch_size = embedding_cfg['batch_size']
    # Ceiling division: a trailing partial batch still counts as one batch.
    num_batches = -(-len(data) // batch_size)

    rows = []
    for batch_idx in range(num_batches):
        start = batch_idx * batch_size
        sentences = data[start:start + batch_size]
        encoded_input = tokenizer(
            sentences,
            max_length=config['embedding']['max_length'],
            padding=True,
            truncation=True,
            return_tensors='pt',
        )
        # The return value is deliberately not re-bound: this relies on
        # .to() updating the tokenizer output in place.
        encoded_input.to(device)

        with torch.no_grad():
            model_output = model(**encoded_input)

        # Take position 0 along axis 1 of the first model output
        # (presumably the [CLS] token vector -- confirm for the encoder used),
        # then scale every row to unit L2 norm.
        first_token = model_output[0][:, 0]
        unit_rows = torch.nn.functional.normalize(first_token, p=2, dim=1)
        rows.extend(unit_rows.cpu().numpy())

    return np.array(rows)
39
+
40
def predict(batch_text):
    """Score a batch of texts with every classifier listed in ``config``.

    Args:
        batch_text: List of strings to classify.

    Returns:
        Dict keyed by category name. Each value holds ``'scores'`` (one
        score per text) and ``'predictions'``, a dict mapping
        ``'high_recall'``, ``'balanced'`` and ``'high_precision'`` to lists
        of 0/1 flags (1 when the score is >= that category's threshold).
    """
    threshold_names = ('high_recall', 'balanced', 'high_precision')

    # Nothing to embed or classify: return the usual structure with empty
    # lists instead of loading the models and failing on a zero-row input.
    if len(batch_text) == 0:
        return {
            category: {
                'scores': [],
                'predictions': {name: [] for name in threshold_names},
            }
            for category in config['classifier']
        }

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    embeddings = get_embeddings(device, batch_text)

    # Cast to float32 before handing the matrix to ONNX Runtime
    # (presumably the exported models declare a float32 input -- confirm).
    X_input = np.asarray(embeddings, dtype=np.float32)

    results = {}
    for category, details in config['classifier'].items():

        # Download the classifier from the Hugging Face Hub. `details` is
        # config['classifier'][category]; the previous lookup used the
        # misspelled key 'classifer' and raised KeyError.
        local_model_fp = hf_hub_download(repo_id=repo_path, filename=details['model_fp'])

        # Run the inference
        session = rt.InferenceSession(local_model_fp)
        input_name = session.get_inputs()[0].name
        outputs = session.run(None, {input_name: X_input})

        if details['calibrated']:
            # Calibrated models: keep only entry 1 (the unsafe class) of
            # each per-text item in the second output.
            scores = [output[1] for output in outputs[1]]
        else:
            # Uncalibrated models: the second output already carries a
            # single score per text for the unsafe class; flatten it to 1-D.
            scores = outputs[1].flatten()

        # Apply each recommended threshold to the scores
        thresholds = details['threshold']
        results[category] = {
            'scores': scores,
            'predictions': {
                name: [1 if score >= thresholds[name] else 0 for score in scores]
                for name in threshold_names
            },
        }

    return results
80
+
81
if __name__ == "__main__":

    # The first command-line argument is a JSON document with the texts.
    batch_text = json.loads(sys.argv[1])

    # Score everything once, then report text by text.
    results = predict(batch_text)
    separator = '---------------------------------------------'
    for idx, text in enumerate(batch_text):
        print(f"Text: '{text}'")
        for category, entry in results.items():
            flags = entry['predictions']
            print(f"[Text {idx+1}] {category} score: {entry['scores'][idx]:.3f} | HR: {flags['high_recall'][idx]}, B: {flags['balanced'][idx]}, HP: {flags['high_precision'][idx]}")
        print(separator)