nmarafo committed
Commit 266545e (parent: 684c246)

Update app.py

Files changed (1)
  1. app.py +18 -17
app.py CHANGED
```diff
@@ -9,9 +9,12 @@ token = os.environ.get("HF_TOKEN")
 
 model_id = "google/shieldgemma-2b"
 
-# use quantization to lower GPU usage
+# Use quantization to reduce GPU usage
 bnb_config = BitsAndBytesConfig(
-    load_in_4bit=True, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16
+    load_in_4bit=True,
+    bnb_4bit_use_double_quant=True,
+    bnb_4bit_quant_type="nf4",
+    bnb_4bit_compute_dtype=torch.bfloat16,
 )
 
 tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
```
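This hunk only reflows `BitsAndBytesConfig` to one argument per line; the quantization settings are unchanged. For context, NF4 with double quantization stores the weights as 4-bit values while computing in bfloat16, which is what lets the 2B model fit on a small GPU. A quick way to confirm the savings (a sketch, assuming access to the gated checkpoint via `HF_TOKEN`) is the standard `get_memory_footprint()` helper:

```python
# Sketch: load with the same 4-bit NF4 config and report the footprint.
import os
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,                      # store weights as 4-bit values
    bnb_4bit_use_double_quant=True,         # also quantize the quantization constants
    bnb_4bit_quant_type="nf4",              # NormalFloat4 quantization
    bnb_4bit_compute_dtype=torch.bfloat16,  # run matmuls in bfloat16
)

model = AutoModelForCausalLM.from_pretrained(
    "google/shieldgemma-2b",
    quantization_config=bnb_config,
    device_map="auto",
    token=os.environ.get("HF_TOKEN"),
)
print(f"Footprint: {model.get_memory_footprint() / 2**30:.2f} GiB")
```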
```diff
@@ -20,7 +23,7 @@ model = AutoModelForCausalLM.from_pretrained(
     torch_dtype=torch.bfloat16,
     device_map="auto",
     quantization_config=bnb_config,
-    token=token
+    token=token,
 )
 
 # Function to generate the prompt including all policies
```
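The only change here is the trailing comma after `token=token`. The token matters because google/shieldgemma-2b is a gated checkpoint: in a Space it arrives through the `HF_TOKEN` secret, while locally you can authenticate once and drop the `token=` arguments, e.g.:

```python
# Local alternative (assumes HF_TOKEN is set in the environment).
import os
from huggingface_hub import login

login(token=os.environ["HF_TOKEN"])  # subsequent from_pretrained calls reuse it
```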
```diff
@@ -63,16 +66,16 @@ def generar_prompt(message, idioma):
 
 @spaces.GPU(duration=150)
 # Function to process the response and check policies
-def respond(message, history, system_message, max_tokens, temperature, top_p, language):
+def respond(message, language, system_message, max_tokens, temperature, top_p):
     # Check policies
     prompt = generar_prompt(message, language)
     inputs = tokenizer(prompt, return_tensors="pt").to("cpu")
     outputs = model.generate(
         **inputs,
         max_new_tokens=50,
-        temperature=temperature,
-        top_p=top_p,
-        do_sample=True,
+        temperature=0.5,
+        top_p=1.0,
+        do_sample=False,
     )
     response_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
     first_word = response_text.strip().split()[0]
```
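Two things change in this hunk: `respond` now takes `language` as its second parameter instead of the unused `history`, and generation becomes deterministic. Note that with `do_sample=False`, `generate()` ignores `temperature` and `top_p` (and warns about them), so the hardcoded `temperature=0.5, top_p=1.0` are inert and could simply be dropped. A minimal equivalent, sketched with two further tweaks the commit does not make: sending inputs to `model.device` instead of `"cpu"` (the quantized model lives on the GPU), and decoding only the newly generated tokens so `first_word` is read from the verdict rather than from the prompt:

```python
# Sketch of an equivalent greedy call; not the committed code.
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=50, do_sample=False)

# Slice off the prompt so only the model's verdict is decoded.
new_tokens = outputs[0][inputs["input_ids"].shape[-1]:]
response_text = tokenizer.decode(new_tokens, skip_special_tokens=True)
first_word = response_text.strip().split()[0] if response_text.strip() else ""
```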
```diff
@@ -84,11 +87,12 @@ def respond(message, history, system_message, max_tokens, temperature, top_p, language):
         violation = first_word
     else:
         violation = violation_keywords[1]  # Assume 'No' if it cannot be determined
-    if violation == violation_keywords[0]:  # '' or 'Yes'
+    if violation in ['Yes', '']:
         if language == "Español":
-            return "Lo siento, pero no puedo ayudar con esa solicitud."
+            violation_message = "Su pregunta viola las políticas aceptadas."
         else:
-            return "I'm sorry, but I cannot assist with that request."
+            violation_message = "Your question violates the accepted policies."
+        return violation_message
     else:
         # Generate response for the user
         if language == "Español":
```
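The rewritten branch stores a bilingual refusal in `violation_message` and returns it once instead of returning from each arm. On the check itself: matching the first generated word is fragile, and the ShieldGemma model card instead reads the probability of the `Yes` token from the final-position logits. A sketch adapted from that example:

```python
# Score-based violation check, adapted from the ShieldGemma model card.
import torch
from torch.nn.functional import softmax

with torch.no_grad():
    logits = model(**inputs).logits

vocab = tokenizer.get_vocab()
selected_logits = logits[0, -1, [vocab["Yes"], vocab["No"]]]
probabilities = softmax(selected_logits, dim=0)
yes_probability = probabilities[0].item()  # P("Yes") = probability of a violation

violation = "Yes" if yes_probability > 0.5 else "No"
```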
```diff
@@ -115,20 +119,17 @@ with gr.Blocks() as demo:
     gr.Markdown("# Chatbot con Verificación de Políticas")
     language = gr.Dropdown(choices=["English", "Español"], value="English", label="Idioma/Language")
     system_message = gr.Textbox(value="You are a friendly Chatbot.", label="System message")
-    #max_tokens = gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens")
-    #temperature = gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature")
-    #top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)")
     chatbot = gr.Chatbot()
     message = gr.Textbox(label="Your message")
     submit_button = gr.Button("Send")
 
-    max_tokens=512
-    temperature=0.7
-    top_p=0.95
+    max_tokens = 512
+    temperature = 0.7
+    top_p = 0.95
 
     def submit_message(user_message, chat_history, system_message, max_tokens, temperature, top_p, language):
         chat_history = chat_history + [[user_message, None]]
-        assistant_reply = respond(user_message, chat_history, system_message, max_tokens, temperature, top_p, language)
+        assistant_reply = respond(user_message, language, system_message, max_tokens, temperature, top_p)
         chat_history[-1][1] = assistant_reply
         return "", chat_history
```
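With the sliders commented out, `max_tokens`, `temperature`, and `top_p` become plain constants (the new lines only add spaces around `=`), and `submit_message` forwards them to the reordered `respond` signature. The button wiring sits outside this diff; presumably something along these lines, where the constants are captured by closure because Gradio `inputs` must be components:

```python
# Hypothetical wiring (not shown in this diff).
submit_button.click(
    lambda user_message, chat_history, sys_msg, lang: submit_message(
        user_message, chat_history, sys_msg, max_tokens, temperature, top_p, lang
    ),
    inputs=[message, chatbot, system_message, language],
    outputs=[message, chatbot],  # submit_message returns ("", chat_history)
)
```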
 
 