Royrotem100 committed on
Commit
123f4c7
•
1 Parent(s): 4c9d398

Add DictaLM 2.0 instruct model 6

Browse files
Files changed (1) hide show
  1. app.py +104 -68
app.py CHANGED
@@ -1,20 +1,25 @@
1
  import os
2
  import gradio as gr
3
  from http import HTTPStatus
4
- import openai
5
  from typing import Generator, List, Optional, Tuple, Dict
 
6
  from urllib.error import HTTPError
 
 
 
 
 
7
 
8
- API_URL = os.getenv('API_URL')
9
- API_KEY = os.getenv('API_KEY')
10
- CUSTOM_JS = os.getenv('CUSTOM_JS', None)
11
- oai_client = openai.OpenAI(api_key=API_KEY, base_url=API_URL)
12
 
13
  History = List[Tuple[str, str]]
14
  Messages = List[Dict[str, str]]
15
 
16
  def clear_session() -> History:
17
- return '', []
18
 
19
  def history_to_messages(history: History) -> Messages:
20
  messages = []
@@ -23,12 +28,51 @@ def history_to_messages(history: History) -> Messages:
23
  messages.append({'role': 'assistant', 'content': h[1].strip()})
24
  return messages
25
 
26
- def messages_to_history(messages: Messages) -> Tuple[str, History]:
27
  history = []
28
  for q, r in zip(messages[0::2], messages[1::2]):
29
- history.append([q['content'], r['content']])
30
  return history
31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  def model_chat(query: Optional[str], history: Optional[History]) -> Generator[Tuple[str, History], None, None]:
33
  if query is None:
34
  query = ''
@@ -36,77 +80,69 @@ def model_chat(query: Optional[str], history: Optional[History]) -> Generator[Tu
36
  history = []
37
  if not query.strip():
38
  return
39
- messages = history_to_messages(history)
40
- messages.append({'role': 'user', 'content': query.strip()})
41
- gen = oai_client.chat.completions.create(
42
- model='dicta-il/dictalm2.0-instruct',
43
- messages=messages,
44
- temperature=0.7,
45
- max_tokens=1024,
46
- top_p=0.9,
47
- stream=True
48
- )
49
- full_response = ''
50
- for completion in gen:
51
- text = completion.choices[0].delta.content
52
- full_response += text or ''
53
- yield full_response
54
-
55
  with gr.Blocks(css='''
56
  .gr-group {direction: rtl;}
57
  .chatbot{text-align:right;}
58
- .dicta-header {
59
- background-color: var(--input-background-fill); /* Replace with desired background color */
60
- border-radius: 10px;
61
- padding: 20px;
62
- text-align: center;
63
- display: flex;
64
- flex-direction: row;
65
- align-items: center;
66
- box-shadow: var(--block-shadow);
67
- border-color: var(--block-border-color);
68
- border-width: 1px;
69
- }
70
-
71
-
72
- @media (max-width: 768px) {
73
  .dicta-header {
74
- flex-direction: column; /* Change to vertical for mobile devices */
 
 
 
 
 
 
 
 
 
75
  }
76
- }
77
-
78
- .chatbot.prose {
79
- font-size: 1.2em;
80
- }
81
- .dicta-logo {
82
- width: 150px; /* Replace with actual logo width as desired */
83
- height: auto;
84
- margin-bottom: 20px;
85
- }
86
-
87
- .dicta-intro-text {
88
- margin-bottom: 20px;
89
- text-align: center;
90
- display: flex;
91
- flex-direction: column;
92
- align-items: center;
93
- width: 100%;
94
- font-size: 1.1em;
95
- }
96
 
97
- textarea {
98
- font-size: 1.2em;
99
- }
100
- ''', js=CUSTOM_JS) as demo:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
  gr.Markdown("""
102
  <div class="dicta-header">
103
  <a href="">
104
- <img src="file/logo111.png" alt="Dicta Logo" class="dicta-logo">
105
  </a>
106
  <div class="dicta-intro-text">
107
- <h1>ืฆ'ืื˜ ืžืขืจื›ื™ - ื”ื“ื’ืžื” ืจืืฉื•ื ื™ืช</h1>
108
  <span dir='rtl'>ื‘ืจื•ื›ื™ื ื”ื‘ืื™ื ืœื“ืžื• ื”ืื™ื ื˜ืจืืงื˜ื™ื‘ื™ ื”ืจืืฉื•ืŸ. ื—ืงืจื• ืืช ื™ื›ื•ืœื•ืช ื”ืžื•ื“ืœ ื•ืจืื• ื›ื™ืฆื“ ื”ื•ื ื™ื›ื•ืœ ืœืกื™ื™ืข ืœื›ื ื‘ืžืฉื™ืžื•ืชื™ื›ื</span><br/>
109
- <span dir='rtl'>ื”ื“ืžื• ื ื›ืชื‘ ืขืœ ื™ื“ื™ ืกืจืŸ ืจื•ืขื™ ืจืชื ืชื•ืš ืฉื™ืžื•ืฉ ื‘ืžื•ื“ืœ ืฉืคื” ื“ื™ืงื˜ื” ืฉืคื•ืชื— ืขืœ ื™ื“ื™ ืžืคื"ืช</span><br/>
110
  </div>
111
  </div>
112
  """)
@@ -118,4 +154,4 @@ with gr.Blocks(css='''
118
  interface.textbox.text_align = 'right'
119
  interface.theme_css += '.gr-group {direction: rtl !important;}'
120
 
121
- demo.queue(api_open=False).launch(max_threads=20, share=False, allowed_paths=['dicta-logo.jpg'])
 
1
  import os
2
  import gradio as gr
3
  from http import HTTPStatus
 
4
  from typing import Generator, List, Optional, Tuple, Dict
5
+ import re
6
  from urllib.error import HTTPError
7
+ from flask import Flask, request, jsonify
8
+ from transformers import AutoTokenizer, AutoModelForCausalLM
9
+ import threading
10
+ import requests
11
+ import torch
12
 
13
# Load the model and tokenizer once at module import.
# NOTE(review): from_pretrained downloads weights on first run — this blocks
# startup until the model is cached locally.
model_name = "dicta-il/dictalm2.0-instruct"
# bfloat16 halves memory vs float32; assumes the host hardware supports it — TODO confirm
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(model_name)
 
# Type aliases for chat state.
History = List[Tuple[str, str]]   # list of (user_text, assistant_text) turn pairs
Messages = List[Dict[str, str]]   # role/content dicts, e.g. {'role': 'user', 'content': ...}
 
21
def clear_session() -> History:
    """Start a fresh conversation by handing back an empty history."""
    return list()
23
 
24
  def history_to_messages(history: History) -> Messages:
25
  messages = []
 
28
  messages.append({'role': 'assistant', 'content': h[1].strip()})
29
  return messages
30
 
31
def messages_to_history(messages: Messages) -> History:
    """Fold a flat, role-alternating message list back into (user, assistant) pairs.

    Messages at even indices are treated as user turns, odd indices as the
    assistant replies; a trailing unpaired message is dropped.
    """
    turn_pairs = zip(messages[0::2], messages[1::2])
    return [(user['content'], reply['content']) for user, reply in turn_pairs]
36
 
37
# Flask app that exposes the model over a local HTTP endpoint.
app = Flask(__name__)

@app.route('/predict', methods=['POST'])
def predict():
    """Generate a model completion for a JSON payload of the form {"text": ...}.

    Returns:
        JSON response {"prediction": <generated text>}.
    """
    data = request.json
    input_text = data.get('text', '')

    # Wrap the prompt with the model's instruction tokens.
    formatted_text = f"<s>[INST] {input_text} [/INST]"

    # Tokenize the input.
    inputs = tokenizer(formatted_text, return_tensors='pt', padding=True, truncation=True)

    # Inference only — disable autograd to save memory and time.
    with torch.no_grad():
        outputs = model.generate(
            inputs['input_ids'],
            attention_mask=inputs['attention_mask'],
            # max_new_tokens bounds the *generated* length. The previous
            # max_length=1024 counted the prompt tokens too, silently
            # shrinking the reply for long prompts.
            max_new_tokens=1024,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id
        )

    # Decode only the newly generated tokens. The previous approach decoded
    # the full sequence and str.replace()'d formatted_text out of it, which
    # never matched: skip_special_tokens strips the leading <s>, so the
    # prompt text (with "<s>") was not present in the decoded string and
    # leaked into the response.
    prompt_len = inputs['input_ids'].shape[1]
    prediction = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True).strip()

    return jsonify({"prediction": prediction})
66
+
67
def run_flask() -> None:
    # Serve the prediction API on all interfaces, port 5000; model_chat calls
    # it at http://127.0.0.1:5000/predict. app.run blocks, so this is intended
    # to be launched from a background thread.
    app.run(host='0.0.0.0', port=5000)
69
+
70
def is_hebrew(text: str) -> bool:
    """Return True when *text* contains at least one character from the
    Hebrew Unicode block (U+0590 through U+05FF)."""
    return any('\u0590' <= ch <= '\u05FF' for ch in text)
72
+
73
+ # Run Flask in a separate thread
74
+ threading.Thread(target=run_flask).start()
75
+
76
  def model_chat(query: Optional[str], history: Optional[History]) -> Generator[Tuple[str, History], None, None]:
77
  if query is None:
78
  query = ''
 
80
  history = []
81
  if not query.strip():
82
  return
83
+
84
+ response = requests.post("http://127.0.0.1:5000/predict", json={"text": query.strip()})
85
+ if response.status_code == 200:
86
+ prediction = response.json().get("prediction", "")
87
+ history.append((query, prediction))
88
+ yield history
89
+ else:
90
+ yield history
91
+
 
 
 
 
 
 
 
92
  with gr.Blocks(css='''
93
  .gr-group {direction: rtl;}
94
  .chatbot{text-align:right;}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
  .dicta-header {
96
+ background-color: var(--input-background-fill); /* Replace with desired background color */
97
+ border-radius: 10px;
98
+ padding: 20px;
99
+ text-align: center;
100
+ display: flex;
101
+ flex-direction: row;
102
+ align-items: center;
103
+ box-shadow: var(--block-shadow);
104
+ border-color: var(--block-border-color);
105
+ border-width: 1px;
106
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
 
108
+ @media (max-width: 768px) {
109
+ .dicta-header {
110
+ flex-direction: column; /* Change to vertical for mobile devices */
111
+ }
112
+ }
113
+
114
+ .chatbot.prose {
115
+ font-size: 1.2em;
116
+ }
117
+ .dicta-logo {
118
+ width: 150px; /* Replace with actual logo width as desired */
119
+ height: auto;
120
+ margin-bottom: 20px;
121
+ }
122
+
123
+ .dicta-intro-text {
124
+ margin-bottom: 20px;
125
+ text-align: center;
126
+ display: flex;
127
+ flex-direction: column;
128
+ align-items: center;
129
+ width: 100%;
130
+ font-size: 1.1em;
131
+ }
132
+
133
+ textarea {
134
+ font-size: 1.2em;
135
+ }
136
+ ''', js=None) as demo:
137
  gr.Markdown("""
138
  <div class="dicta-header">
139
  <a href="">
140
+ <img src="file/logo_am.png" alt="Dicta Logo" class="dicta-logo">
141
  </a>
142
  <div class="dicta-intro-text">
143
+ <h1>ื”ื“ื’ืžื” ืจืืฉื•ื ื™ืช</h1>
144
  <span dir='rtl'>ื‘ืจื•ื›ื™ื ื”ื‘ืื™ื ืœื“ืžื• ื”ืื™ื ื˜ืจืืงื˜ื™ื‘ื™ ื”ืจืืฉื•ืŸ. ื—ืงืจื• ืืช ื™ื›ื•ืœื•ืช ื”ืžื•ื“ืœ ื•ืจืื• ื›ื™ืฆื“ ื”ื•ื ื™ื›ื•ืœ ืœืกื™ื™ืข ืœื›ื ื‘ืžืฉื™ืžื•ืชื™ื›ื</span><br/>
145
+ <span dir='rtl'>ื”ื“ืžื• ื ื›ืชื‘ ืขืœ ื™ื“ื™ ืจื•ืขื™ ืจืชื ืชื•ืš ืฉื™ืžื•ืฉ ื‘ืžื•ื“ืœ ืฉืคื” ื“ื™ืงื˜ื” ืฉืคื•ืชื— ืขืœ ื™ื“ื™ ืžืคื"ืช</span><br/>
146
  </div>
147
  </div>
148
  """)
 
154
  interface.textbox.text_align = 'right'
155
  interface.theme_css += '.gr-group {direction: rtl !important;}'
156
 
157
+ demo.queue(api_open=False).launch(max_threads=20, share=False, allowed_paths=['logo_am.png'])