nmarafo committed
Commit 7809349
1 Parent(s): c917c90

Update app.py

Files changed (1)
  1. app.py +64 -89
app.py CHANGED
@@ -3,6 +3,7 @@ import os
 import spaces
 from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
 import torch
+import torch.nn.functional as F
 
 # Load the Hugging Face token from the secrets
 token = os.environ.get("HF_TOKEN")
@@ -26,92 +27,72 @@ model = AutoModelForCausalLM.from_pretrained(
     token=token,
 )
 
-# Function to generate the prompt, including all the policies
-def generar_prompt(message, idioma):
-    if idioma == "Español":
-        politica_seguridad = """
-        * "No al contenido peligroso": El mensaje no debe contener ni buscar la generación de contenido que cause daño a uno mismo o a otros (por ejemplo, acceder o fabricar armas de fuego o explosivos, promoción del terrorismo, instrucciones para el suicidio).
-        * "No al acoso": El mensaje no debe contener ni buscar la generación de contenido que sea malicioso, intimidante, que fomente el bullying o abusivo dirigido a otra persona (por ejemplo, amenazas físicas, negación de eventos trágicos, menosprecio a víctimas de violencia).
-        * "No al discurso de odio": El mensaje no debe contener ni buscar la generación de contenido que ataque la identidad o atributos protegidos (por ejemplo, insultos raciales, promoción de la discriminación, llamados a la violencia contra grupos protegidos) o que deshumanice, menosprecie o vilipendie por motivos de raza, etnia, religión, discapacidad, edad, nacionalidad, estatus de veterano, orientación sexual, género, identidad de género, casta, u otro estatus protegido.
-        * "No al contenido sexual explícito": El mensaje no debe contener ni buscar la generación de contenido que haga referencia a actos sexuales u otro contenido lascivo (por ejemplo, descripciones gráficas sexuales, contenido destinado a causar excitación). Los términos médicos o científicos relacionados con la anatomía humana o la educación sexual están permitidos.
-        """
-        prompt = f"""Eres un experto en políticas tratando de determinar si un mensaje del usuario
-viola las políticas de seguridad definidas.
-<inicio_turno>
-Pregunta del usuario: {message.strip()}
-<fin_turno>
-Nuestro principio de seguridad se define a continuación:
-{politica_seguridad.strip()}
-¿La pregunta del usuario viola alguno de los principios anteriores? Tu respuesta debe comenzar con
-'Sí' o 'No'. Luego, explica paso a paso para asegurarte de que respondemos correctamente.
-"""
-    else:
-        politica_seguridad = """
-        * "No Dangerous Content": The prompt shall not contain or seek generation of content that harms oneself and/or others (e.g., accessing or building firearms and explosive devices, promotion of terrorism, instructions for suicide).
-        * "No Harassment": The prompt shall not contain or seek generation of content that is malicious, intimidating, bullying, or abusive content targeting another individual (e.g., physical threats, denial of tragic events, disparaging victims of violence).
-        * "No Hate Speech": The prompt shall not contain or seek generation of content that targets identity and/or protected attributes (e.g., racial slurs, promotion of discrimination, calls to violence against protected groups) or dehumanizing/belittling/vilifying based on race, ethnicity, religion, disability, age, nationality, veteran status, sexual orientation, gender, gender identity, caste, or any other protected status.
-        * "No Sexually Explicit Information": The prompt shall not contain or seek generation of content that contains references to sexual acts or other lewd content (e.g., sexually graphic descriptions, content aimed at causing arousal). Medical or scientific terms related to human anatomy or sex education are permitted.
-        """
-        prompt = f"""You are a policy expert trying to help determine whether a user prompt is in
-violation of the defined safety policies.
-<start_of_turn>
-Human Question: {message.strip()}
-<end_of_turn>
-Our safety principles are defined below:
-{politica_seguridad.strip()}
-Does the human question violate any of the above principles? Your answer must start with
-'Yes' or 'No'. Then walk through step by step to be sure we answer correctly.
-"""
+# Define the 'system_message' directly in the code
+system_message = "You are a friendly Chatbot."
+
+# Function to generate the prompt, including the policies
+def generar_prompt(user_prompt):
+    safety_policy = """
+    * "No Dangerous Content": The prompt shall not contain or seek generation of content that harms oneself and/or others (e.g., accessing or building firearms and explosive devices, promotion of terrorism, instructions for suicide).
+    * "No Harassment": The prompt shall not contain or seek generation of content that is malicious, intimidating, bullying, or abusive content targeting another individual (e.g., physical threats, denial of tragic events, disparaging victims of violence).
+    * "No Hate Speech": The prompt shall not contain or seek generation of content that targets identity and/or protected attributes (e.g., racial slurs, promotion of discrimination, calls to violence against protected groups) or dehumanizing/belittling/vilifying based on race, ethnicity, religion, disability, age, nationality, veteran status, sexual orientation, gender, gender identity, caste, or any other protected status.
+    * "No Sexually Explicit Information": The prompt shall not contain or seek generation of content that contains references to sexual acts or other lewd content (e.g., sexually graphic descriptions, content aimed at causing arousal). Medical or scientific terms related to human anatomy or sex education are permitted.
+    """
+    prompt = f"""You are a policy expert trying to help determine whether a user prompt is in violation of the defined safety policies.
+
+<start_of_turn>
+Human Question: {user_prompt.strip()}
+<end_of_turn>
+
+Our safety principle is defined in the below:
+
+{safety_policy.strip()}
+
+Does the human question violate the above principle? Your answer must start with 'Yes' or 'No'. And then walk through step by step to be sure we answer correctly.
+"""
     return prompt
 
 @spaces.GPU(duration=150)
-def respond(message, language, system_message, max_tokens, temperature, top_p):
-    prompt = generar_prompt(message, language)
-    inputs = tokenizer(prompt, return_tensors="pt")  # Do not pin a device here
-    outputs = model.generate(
-        **inputs,
-        max_new_tokens=1,
-        temperature=0.0,
-        do_sample=False,
-        return_dict_in_generate=True,
-        output_scores=True,
-    )
-    # Get the logits of the generated token
-    logits = outputs.scores[0]  # Only one generation step
-    # Get the token IDs for "Yes" and "No"
-    if language == "Español":
-        yes_token_id = tokenizer.encode('Sí', add_special_tokens=False)[0]
-        no_token_id = tokenizer.encode('No', add_special_tokens=False)[0]
-    else:
-        yes_token_id = tokenizer.encode('Yes', add_special_tokens=False)[0]
-        no_token_id = tokenizer.encode('No', add_special_tokens=False)[0]
-    # Extract the logits for "Yes" and "No"
-    selected_logits = logits[0, [yes_token_id, no_token_id]]
-    # Compute the probabilities
-    probabilities = torch.softmax(selected_logits, dim=0)
+# Function to process the response and check the policies
+def respond(message, max_tokens, temperature, top_p):
+    prompt = generar_prompt(message)
+    inputs = tokenizer(prompt, return_tensors="pt")
+
+    with torch.no_grad():
+        outputs = model(**inputs)
+        logits = outputs.logits
+
+    # Get the vocabulary
+    vocab = tokenizer.get_vocab()
+    # Token IDs for 'Yes' and 'No'
+    yes_token_id = vocab.get('Yes')
+    no_token_id = vocab.get('No')
+
+    # Check that the tokens exist in the vocabulary
+    if yes_token_id is None or no_token_id is None:
+        raise ValueError("The tokens 'Yes' or 'No' were not found in the vocabulary.")
+
+    # Extract the logits for 'Yes' and 'No'
+    selected_logits = logits[0, -1, [yes_token_id, no_token_id]]
+
+    # Compute the probabilities with softmax
+    probabilities = F.softmax(selected_logits, dim=0)
+
+    # Probability of 'Yes' and 'No'
     yes_probability = probabilities[0].item()
     no_probability = probabilities[1].item()
-
-    # Print the probabilities
+
     print(f"Yes probability: {yes_probability}")
    print(f"No probability: {no_probability}")
 
-    # Decide whether there is a policy violation
+    # Decide whether there is a policy violation based on the probability of 'Yes'
     if yes_probability > no_probability:
-        print("Decisión: Yes (viola las políticas)")
-        if language == "Español":
-            violation_message = "Su pregunta viola las políticas aceptadas."
-        else:
-            violation_message = "Your question violates the accepted policies."
+        violation_message = "Your question violates the accepted policies."
         return violation_message
     else:
-        print("Decisión: No (no viola las políticas)")
         # Generate the response to the user
-        if language == "Español":
-            assistant_prompt = f"{system_message}\nUsuario: {message}\nAsistente:"
-        else:
-            assistant_prompt = f"{system_message}\nUser: {message}\nAssistant:"
-        inputs = tokenizer(assistant_prompt, return_tensors="pt").to("cpu")
+        assistant_prompt = f"{system_message}\nUser: {message}\nAssistant:"
+        inputs = tokenizer(assistant_prompt, return_tensors="pt")
         outputs = model.generate(
             **inputs,
             max_new_tokens=max_tokens,
@@ -120,20 +101,14 @@ def respond(message, language, system_message, max_tokens, temperature, top_p):
             do_sample=True,
         )
         assistant_response = tokenizer.decode(outputs[0], skip_special_tokens=True)
-        if language == "Español":
-            assistant_reply = assistant_response.split("Asistente:")[-1].strip()
-        else:
-            assistant_reply = assistant_response.split("Assistant:")[-1].strip()
+        assistant_reply = assistant_response.split("Assistant:")[-1].strip()
        return assistant_reply
 
-
 # Create the Gradio interface using Blocks
 with gr.Blocks() as demo:
-    gr.Markdown("# Chatbot con Verificación de Políticas")
-    language = gr.Dropdown(choices=["English", "Español"], value="English", label="Idioma/Language")
-    system_message = "You are a friendly Chatbot."
+    gr.Markdown("# Child-Safe-Chatbot")
 
-    with gr.Accordion("Avanzado", open=False):
+    with gr.Accordion("Advanced", open=False):
         max_tokens = gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens")
         temperature = gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature")
         top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)")
@@ -142,23 +117,23 @@ with gr.Blocks() as demo:
     message = gr.Textbox(label="Your message")
     submit_button = gr.Button("Send")
 
-    def submit_message(user_message, chat_history, max_tokens, temperature, top_p, language):
+    def submit_message(user_message, chat_history, max_tokens, temperature, top_p):
         chat_history = chat_history + [[user_message, None]]
-        assistant_reply = respond(user_message, language, system_message, max_tokens, temperature, top_p)
+        assistant_reply = respond(
+            user_message, max_tokens, temperature, top_p
+        )
         chat_history[-1][1] = assistant_reply
         return "", chat_history
 
-
     submit_button.click(
         submit_message,
-        inputs=[message, chatbot, max_tokens, temperature, top_p, language],
+        inputs=[message, chatbot, max_tokens, temperature, top_p],
         outputs=[message, chatbot],
     )
     message.submit(
         submit_message,
-        inputs=[message, chatbot, max_tokens, temperature, top_p, language],
+        inputs=[message, chatbot, max_tokens, temperature, top_p],
         outputs=[message, chatbot],
     )
 
-
     demo.launch(debug=True)
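
Note on the new scoring path: the updated respond() replaces the old one-token model.generate(..., output_scores=True) call with a single plain forward pass, then compares the next-token logits for 'Yes' and 'No' directly. Below is a minimal standalone sketch of that technique, assuming a ShieldGemma-style checkpoint (the actual model name is truncated out of this diff) and omitting the BitsAndBytesConfig quantization for brevity.

import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_ID = "google/shieldgemma-2b"  # assumption: the real checkpoint name is not visible in this diff

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID)

def score_yes_no(prompt: str) -> tuple[float, float]:
    """Return (P('Yes'), P('No')) for the token the model would generate next."""
    inputs = tokenizer(prompt, return_tensors="pt")
    with torch.no_grad():
        # logits has shape (batch, seq_len, vocab_size); the last position
        # holds the model's scores for the next token.
        logits = model(**inputs).logits

    vocab = tokenizer.get_vocab()
    yes_id, no_id = vocab.get("Yes"), vocab.get("No")
    if yes_id is None or no_id is None:
        raise ValueError("'Yes'/'No' are not single tokens in this vocabulary.")

    # Softmax over just the two candidate tokens turns the pair of logits
    # into a binary verdict, exactly as the diff's respond() does.
    pair = logits[0, -1, [yes_id, no_id]]
    p_yes, p_no = F.softmax(pair, dim=0).tolist()
    return p_yes, p_no

For a single greedy step this matches the previous generate(max_new_tokens=1, output_scores=True) approach (up to any logits processors generate() applies), while skipping the generation loop entirely.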
 
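
A caveat on the vocabulary lookup above: tokenizer.get_vocab().get('Yes') only succeeds when 'Yes' is stored verbatim as a single piece. SentencePiece vocabularies, as used by Gemma-family models, often store word-initial pieces with a leading '▁' marker, so the lookup can trip the ValueError branch even though the model can emit the word. A more defensive lookup is sketched below; token_id_for is a hypothetical helper, not part of this commit.

def token_id_for(tokenizer, word: str) -> int:
    """Resolve `word` to a single token id, also trying the SentencePiece word-initial piece."""
    vocab = tokenizer.get_vocab()
    for candidate in (word, "\u2581" + word):  # '\u2581' is the SentencePiece '▁' marker
        if candidate in vocab:
            return vocab[candidate]
    # Fall back to encoding; insist on a single id so the two-way softmax stays well-defined.
    ids = tokenizer.encode(word, add_special_tokens=False)
    if len(ids) != 1:
        raise ValueError(f"{word!r} does not map to a single token in this vocabulary.")
    return ids[0]

The returned id can be fed to the logits[0, -1, [yes_token_id, no_token_id]] indexing unchanged.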