fakeym committed on
Commit 3d97a67
1 Parent(s): a09d5ad

Update app.py

Files changed (1)
  1. app.py +47 -47
app.py CHANGED
@@ -1,48 +1,48 @@
- import string
- import gradio as gr
- import requests
- import torch
- from transformers import (
-     AutoConfig,
-     AutoModelForSequenceClassification,
-     AutoTokenizer,
- )
-
- custom_labels = {0: "neg", 1: "pos"}
- model_dir = r'model\sst-2-english'
- # model = pipeline("sentiment-analysis",model=model_dir,device=0)
- # print(model("you are bad boy."))
- config = AutoConfig.from_pretrained(model_dir, num_labels=2, finetuning_task="text-classification")
- tokenizer = AutoTokenizer.from_pretrained(model_dir)
- model = AutoModelForSequenceClassification.from_pretrained(model_dir, config=config)
- model.config.id2label = custom_labels
- model.config.label2id = {v: k for k, v in custom_labels.items()}
- def inference(input_text):
-     inputs = tokenizer.batch_encode_plus(
-         [input_text],
-         max_length=512,
-         pad_to_max_length=True,
-         truncation=True,
-         padding="max_length",
-         return_tensors="pt",
-     )
-
-     with torch.no_grad():
-         logits = model(**inputs).logits
-
-     predicted_class_id = logits.argmax().item()
-     output = model.config.id2label[predicted_class_id]
-     return output
-
- demo = gr.Interface(
-     fn=inference,
-     inputs=gr.Textbox(label="Input Text", scale=2, container=False),
-     outputs=gr.Textbox(label="Output Label"),
-     examples = [
-         ["My last two weather pics from the storm on August 2nd. People packed up real fast after the temp dropped and winds picked up.", 1],
-         ["Lying Clinton sinking! Donald Trump singing: Let's Make America Great Again!", 0],
-     ],
-     title="Tutorial: BERT-based Text Classificatioin",
- )
-
  demo.launch(debug=True)
 
+ import string
+ import gradio as gr
+ import requests
+ import torch
+ from transformers import (
+     AutoConfig,
+     AutoModelForSequenceClassification,
+     AutoTokenizer,
+ )
+
+ custom_labels = {0: "neg", 1: "pos"}
+ model_dir = r'model/sst-2-english'
+ # model = pipeline("sentiment-analysis",model=model_dir,device=0)
+ # print(model("you are bad boy."))
+ config = AutoConfig.from_pretrained(model_dir, num_labels=2, finetuning_task="text-classification")
+ tokenizer = AutoTokenizer.from_pretrained(model_dir)
+ model = AutoModelForSequenceClassification.from_pretrained(model_dir, config=config)
+ model.config.id2label = custom_labels
+ model.config.label2id = {v: k for k, v in custom_labels.items()}
+ def inference(input_text):
+     inputs = tokenizer.batch_encode_plus(
+         [input_text],
+         max_length=512,
+         pad_to_max_length=True,
+         truncation=True,
+         padding="max_length",
+         return_tensors="pt",
+     )
+
+     with torch.no_grad():
+         logits = model(**inputs).logits
+
+     predicted_class_id = logits.argmax().item()
+     output = model.config.id2label[predicted_class_id]
+     return output
+
+ demo = gr.Interface(
+     fn=inference,
+     inputs=gr.Textbox(label="Input Text", scale=2, container=False),
+     outputs=gr.Textbox(label="Output Label"),
+     examples = [
+         ["My last two weather pics from the storm on August 2nd. People packed up real fast after the temp dropped and winds picked up.", 1],
+         ["Lying Clinton sinking! Donald Trump singing: Let's Make America Great Again!", 0],
+     ],
+     title="Tutorial: BERT-based Text Classificatioin",
+ )
+
  demo.launch(debug=True)
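
The substantive change in this commit is the separator in `model_dir`: `model\sst-2-english` becomes `model/sst-2-english`, which matters because the Space container runs on Linux, where the backslash form is treated as a single literal filename rather than a relative path. As a minimal sketch (not part of the commit, and assuming the checkpoint folder sits next to `app.py`), a separator-agnostic way to build the same path would be:

from pathlib import Path

# Resolve the checkpoint directory relative to app.py itself; pathlib inserts
# the correct separator on both Windows and Linux, so the same line works
# locally and inside the Space container. The "model/sst-2-english" layout is
# assumed from the diff above.
model_dir = str(Path(__file__).parent / "model" / "sst-2-english")

Recent versions of transformers accept a Path object directly in `from_pretrained`, so the `str(...)` wrapper is mainly there to keep `model_dir` a plain string, as in the original script.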