Alfasign commited on
Commit
4377c12
1 Parent(s): ecc0d8e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -116
app.py CHANGED
@@ -1,120 +1,19 @@
1
- from typing import Any, Dict, Tuple
2
- import warnings
3
 
4
- import torch
5
- from transformers import AutoModelForCausalLM, AutoTokenizer
6
- from transformers import (
7
- StoppingCriteria,
8
- StoppingCriteriaList,
9
- TextIteratorStreamer,
10
- )
11
 
 
 
 
 
 
12
 
13
- INSTRUCTION_KEY = "### Instruction:"
14
- RESPONSE_KEY = "### Response:"
15
- END_KEY = "### End"
16
- INTRO_BLURB = "Below is an instruction that describes a task. Write a response that appropriately completes the request."
17
- PROMPT_FOR_GENERATION_FORMAT = """{intro}
18
- {instruction_key}
19
- {instruction}
20
- {response_key}
21
- """.format(
22
- intro=INTRO_BLURB,
23
- instruction_key=INSTRUCTION_KEY,
24
- instruction="{instruction}",
25
- response_key=RESPONSE_KEY,
26
- )
27
 
28
-
29
- class InstructionTextGenerationPipeline:
30
- def __init__(
31
- self,
32
- model_name,
33
- torch_dtype=torch.bfloat16,
34
- trust_remote_code=True,
35
- use_auth_token=None,
36
- ) -> None:
37
- self.model = AutoModelForCausalLM.from_pretrained(
38
- model_name,
39
- torch_dtype=torch_dtype,
40
- trust_remote_code=trust_remote_code,
41
- use_auth_token=use_auth_token,
42
- )
43
-
44
- tokenizer = AutoTokenizer.from_pretrained(
45
- model_name,
46
- trust_remote_code=trust_remote_code,
47
- use_auth_token=use_auth_token,
48
- )
49
- if tokenizer.pad_token_id is None:
50
- warnings.warn(
51
- "pad_token_id is not set for the tokenizer. Using eos_token_id as pad_token_id."
52
- )
53
- tokenizer.pad_token = tokenizer.eos_token
54
- tokenizer.padding_side = "left"
55
- self.tokenizer = tokenizer
56
-
57
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
58
- self.model.eval()
59
- self.model.to(device=device, dtype=torch_dtype)
60
-
61
- self.generate_kwargs = {
62
- "temperature": 0.1,
63
- "top_p": 0.92,
64
- "top_k": 0,
65
- "max_new_tokens": 1024,
66
- "use_cache": True,
67
- "do_sample": True,
68
- "eos_token_id": self.tokenizer.eos_token_id,
69
- "pad_token_id": self.tokenizer.pad_token_id,
70
- "repetition_penalty": 1.1, # 1.0 means no penalty, > 1.0 means penalty, 1.2 from CTRL paper
71
- }
72
-
73
- def format_instruction(self, instruction):
74
- return PROMPT_FOR_GENERATION_FORMAT.format(instruction=instruction)
75
-
76
- def __call__(
77
- self, instruction: str, **generate_kwargs: Dict[str, Any]
78
- ) -> Tuple[str, str, float]:
79
- s = PROMPT_FOR_GENERATION_FORMAT.format(instruction=instruction)
80
- input_ids = self.tokenizer(s, return_tensors="pt").input_ids
81
- input_ids = input_ids.to(self.model.device)
82
- gkw = {**self.generate_kwargs, **generate_kwargs}
83
- with torch.no_grad():
84
- output_ids = self.model.generate(input_ids, **gkw)
85
- # Slice the output_ids tensor to get only new tokens
86
- new_tokens = output_ids[0, len(input_ids[0]) :]
87
- output_text = self.tokenizer.decode(new_tokens, skip_special_tokens=True)
88
- return output_text
89
-
90
- # Initialize the model and tokenizer
91
- generate = InstructionTextGenerationPipeline(
92
- "mosaicml/mpt-7b-instruct",
93
- torch_dtype=torch.bfloat16,
94
- trust_remote_code=True,
95
- )
96
- stop_token_ids = generate.tokenizer.convert_tokens_to_ids(["<|endoftext|>"])
97
-
98
-
99
- # Define a custom stopping criteria
100
- class StopOnTokens(StoppingCriteria):
101
- def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
102
- for stop_id in stop_token_ids:
103
- if input_ids[0][-1] == stop_id:
104
- return True
105
- return False
106
-
107
- """### The prompt & response"""
108
-
109
- import json
110
- import textwrap
111
-
112
- def get_prompt(instruction):
113
- prompt_template = f"Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:"
114
- return prompt_template
115
-
116
- # print(get_prompt('What is the meaning of life?'))
117
-
118
- def parse_text(text):
119
- wrapped_text = textwrap.fill(text, width=100)
120
- print(wrapped_text +'\n\n')
 
1
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer, TextClassificationPipeline
 
2
 
3
+ MODEL_PATH = "results/checkpoint-6000/" # Ändern Sie dies entsprechend
 
 
 
 
 
 
4
 
5
+ def load_model():
6
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
7
+ model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH)
8
+ pipeline = TextClassificationPipeline(model=model, tokenizer=tokenizer)
9
+ return pipeline
10
 
11
+ def classify_text(text):
12
+ pipeline = load_model()
13
+ result = pipeline(text)
14
+ return result
 
 
 
 
 
 
 
 
 
 
15
 
16
+ if __name__ == "__main__":
17
+ text = input("Geben Sie einen Text ein: ")
18
+ result = classify_text(text)
19
+ print(result)