Thaweewat committed
Commit
89c1475
1 Parent(s): 6189e90

Update README.md

Files changed (1)
  1. README.md +82 -2
README.md CHANGED

@@ -6,7 +6,87 @@ language:
  - th
  metrics:
  - f1
- pipeline_tag: text-classification
  tags:
  - roberta
- ---
+ ---

The rest of the additions are the model card content shown below.

# Traffy Complaint Classification

This model automatically classifies Thai-language complaints reported through the Traffy platform into one or more of 24 categories, reducing the need for manual classification by humans.

### Model Details

- Model name: KDAI-NLP/wangchanberta-traffy-multi
- Tokenizer: airesearch/wangchanberta-base-att-spm-uncased
- License: Apache License 2.0

### How to Use

The tokenizer requires the `sentencepiece` package (`pip install sentencepiece`).

```python
import json

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# The 24 target labels, in the order used by the classification head
target_list = [
    'ความสะอาด', 'สายไฟ', 'สะพาน', 'ถนน', 'น้ำท่วม',
    'ร้องเรียน', 'ท่อระบายน้ำ', 'ความปลอดภัย', 'คลอง', 'แสงสว่าง',
    'ทางเท้า', 'จราจร', 'กีดขวาง', 'การเดินทาง', 'เสียงรบกวน',
    'ต้นไม้', 'สัตว์จรจัด', 'เสนอแนะ', 'คนจรจัด', 'ห้องน้ำ',
    'ป้ายจราจร', 'สอบถาม', 'ป้าย', 'PM2.5'
]

# Load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained("airesearch/wangchanberta-base-att-spm-uncased")
model = AutoModelForSequenceClassification.from_pretrained("KDAI-NLP/wangchanberta-traffy-multi")

# Example text to classify
text = "ช่วยด้วยครับถนนน้ำท่วมอีกแล้ว ต้นไม้ก็ล้มขวางทาง กลับบ้านไม่ได้"

# Encode the text using the tokenizer
inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=256)

# Get model predictions (logits) without tracking gradients
with torch.no_grad():
    logits = model(**inputs).logits

# Sigmoid turns each logit into an independent per-label probability
probabilities = torch.sigmoid(logits).squeeze().tolist()

# Print each label with its probability
for label, probability in zip(target_list, probabilities):
    print(f"{label}: {probability:.4f}")

# Or serialize the results as JSON
results_dict = dict(zip(target_list, probabilities))
results_json = json.dumps(results_dict, ensure_ascii=False, indent=4)
print(results_json)
```
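To turn the per-label probabilities into discrete predictions, a common approach is to keep every label whose probability exceeds a cutoff. The model card does not specify a threshold, so the 0.5 value and the helper name below are assumptions for illustration:

```python
# Hypothetical helper: select labels whose probability exceeds a threshold.
# The 0.5 cutoff is an assumption, not a value from this model card.
def predict_labels(labels, probabilities, threshold=0.5):
    return [label for label, p in zip(labels, probabilities) if p > threshold]

# Toy probabilities for three labels (illustrative values only)
labels = ['ถนน', 'น้ำท่วม', 'ต้นไม้']
probs = [0.91, 0.84, 0.12]
print(predict_labels(labels, probs))  # → ['ถนน', 'น้ำท่วม']
```

Because each probability is produced independently by the sigmoid, a single complaint can receive several labels at once, as in the flooded-road example above.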

## Training Details

The model was fine-tuned on complaint data from the Traffy API (stopwords included), using airesearch/wangchanberta-base-att-spm-uncased as the base model. This is a multi-label classification task with a total of 24 classes.

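A multi-label head like this is typically trained with an independent binary cross-entropy per class (rather than a softmax over all classes). A minimal plain-Python sketch of that objective, with made-up logits and targets for illustration:

```python
import math

def sigmoid(x):
    # Standard logistic function
    return 1.0 / (1.0 + math.exp(-x))

def bce_multilabel(logits, targets):
    """Mean binary cross-entropy over all classes, the usual multi-label objective."""
    total = 0.0
    for z, t in zip(logits, targets):
        p = sigmoid(z)
        total += -(t * math.log(p) + (1 - t) * math.log(1 - p))
    return total / len(logits)

# Toy example with 4 classes; the 1st and 3rd are the true labels
logits = [2.0, -1.5, 0.8, -2.2]
targets = [1, 0, 1, 0]
print(round(bce_multilabel(logits, targets), 4))  # → 0.2011
```

In PyTorch this corresponds to `BCEWithLogitsLoss`; each class contributes its own binary decision, which is what allows several of the 24 labels to be active for one complaint.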
## Training Scores

| Model | Stopword | Epoch | Training Loss | Validation Loss | F1 | Accuracy |
| ---------------------------------- | -------- | ----- | ------------- | --------------- | ------ | -------- |
| wangchanberta-base-att-spm-uncased | Included | 0 | 0.0322 | 0.034822 | 0.7015 | 0.7569 |
| wangchanberta-base-att-spm-uncased | Included | 2 | 0.0207 | 0.026364 | 0.8405 | 0.7821 |
| wangchanberta-base-att-spm-uncased | Included | 4 | 0.0165 | 0.025142 | 0.8458 | 0.7934 |
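The card does not state which F1 averaging was used. As an illustration only, micro-averaged F1 for multi-label predictions (pooling true/false positives across all classes) can be computed like this, on toy data rather than the model's actual outputs:

```python
def micro_f1(y_true, y_pred):
    """Micro-averaged F1 for multi-label data given as lists of 0/1 vectors."""
    tp = fp = fn = 0
    for true_vec, pred_vec in zip(y_true, y_pred):
        for t, p in zip(true_vec, pred_vec):
            if t and p:
                tp += 1      # predicted and correct
            elif p:
                fp += 1      # predicted but wrong
            elif t:
                fn += 1      # missed a true label
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    return 2 * precision * recall / (precision + recall) if precision + recall else 0.0

# Two toy samples over three classes
y_true = [[1, 0, 1], [0, 1, 0]]
y_pred = [[1, 0, 0], [0, 1, 1]]
print(round(micro_f1(y_true, y_pred), 4))  # → 0.6667
```

Micro-averaging weights every label decision equally, which is a common choice when label frequencies are imbalanced, as complaint categories typically are.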