nmarafo committed
Commit e436b6c
1 Parent(s): e2af991

Update app.py

Files changed (1):
  app.py +37 -51
app.py CHANGED
@@ -5,12 +5,10 @@ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
 import torch
 import torch.nn.functional as F
 
-# Load the Hugging Face token from the secrets
 token = os.environ.get("HF_TOKEN")
 
-model_id = "google/shieldgemma-2b"
+model_id = "google/shieldgemma-9b"
 
-# Use quantization to reduce GPU usage
 bnb_config = BitsAndBytesConfig(
     load_in_4bit=True,
     bnb_4bit_use_double_quant=True,
@@ -27,84 +25,72 @@ model = AutoModelForCausalLM.from_pretrained(
     token=token,
 )
 
-# Define the 'system_message' directly in the code
 system_message = "You are a friendly Chatbot."
 
-# Function to build the prompt, including the safety policies
-def generar_prompt(user_prompt):
-    safety_policy = """
-    * "No Dangerous Content": The prompt shall not contain or seek generation of content that harms oneself and/or others (e.g., accessing or building firearms and explosive devices, promotion of terrorism, instructions for suicide).
-    * "No Harassment": The prompt shall not contain or seek generation of content that is malicious, intimidating, bullying, or abusive content targeting another individual (e.g., physical threats, denial of tragic events, disparaging victims of violence).
-    * "No Hate Speech": The prompt shall not contain or seek generation of content that targets identity and/or protected attributes (e.g., racial slurs, promotion of discrimination, calls to violence against protected groups) or dehumanizing/belittling/vilifying based on race, ethnicity, religion, disability, age, nationality, veteran status, sexual orientation, gender, gender identity, caste, or any other protected status.
-    * "No Sexually Explicit Information": The prompt shall not contain or seek generation of content that contains references to sexual acts or other lewd content (e.g., sexually graphic descriptions, content aimed at causing arousal). Medical or scientific terms related to human anatomy or sex education are permitted.
-    """
-    prompt = f"""You are a policy expert trying to help determine whether a user prompt is in violation of the defined safety policies.
-
+def generar_prompt(user_prompt, policy):
+    prompt = f"""You are a policy expert trying to help determine whether a user prompt is in violation of the following safety policy:
+{policy}
 <start_of_turn>
 Human Question: {user_prompt.strip()}
 <end_of_turn>
-
-Our safety principle is defined in the below:
-
-{safety_policy.strip()}
-
-Does the human question violate the above principle? Your answer must start with 'Yes' or 'No'. And then walk through step by step to be sure we answer correctly.
+Does the human question violate the above policy? Your answer must start with 'Yes' or 'No'. And then walk through step by step to be sure we answer correctly.
 """
     return prompt
 
-@spaces.GPU(duration=150)
-# Function to process the response and check the policies
-def respond(message, max_tokens, temperature, top_p):
-    prompt = generar_prompt(message)
+def verificar_politica(message, policy):
+    prompt = generar_prompt(message, policy)
     inputs = tokenizer(prompt, return_tensors="pt")
 
     with torch.no_grad():
         outputs = model(**inputs)
         logits = outputs.logits
 
-    # Get the vocabulary
     vocab = tokenizer.get_vocab()
-    # IDs of the 'Yes' and 'No' tokens
     yes_token_id = vocab.get('Yes')
     no_token_id = vocab.get('No')
 
-    # Check that the tokens exist in the vocabulary
     if yes_token_id is None or no_token_id is None:
         raise ValueError("The 'Yes' or 'No' tokens were not found in the vocabulary.")
 
-    # Extract the logits for 'Yes' and 'No'
     selected_logits = logits[0, -1, [yes_token_id, no_token_id]]
-
-    # Compute the probabilities with softmax
     probabilities = F.softmax(selected_logits, dim=0)
 
-    # Probabilities of 'Yes' and 'No'
     yes_probability = probabilities[0].item()
    no_probability = probabilities[1].item()
 
+    print(f"Policy: {policy}")
     print(f"Yes probability: {yes_probability}")
     print(f"No probability: {no_probability}")
 
-    # Decide whether there is a policy violation based on the probability of 'Yes'
-    if yes_probability > no_probability:
-        violation_message = "Your question violates the accepted policies."
-        return violation_message
-    else:
-        # Generate a response for the user
-        assistant_prompt = f"{system_message}\nUser: {message}\nAssistant:"
-        inputs = tokenizer(assistant_prompt, return_tensors="pt")
-        outputs = model.generate(
-            **inputs,
-            max_new_tokens=max_tokens,
-            temperature=temperature,
-            top_p=top_p,
-            do_sample=True,
-        )
-        assistant_response = tokenizer.decode(outputs[0], skip_special_tokens=True)
-        assistant_reply = assistant_response.split("Assistant:")[-1].strip()
-        return assistant_reply
+    return yes_probability > no_probability
+
+@spaces.GPU(duration=150)
+def respond(message, max_tokens, temperature, top_p):
+    policies = [
+        "No Dangerous Content: The prompt shall not contain or seek generation of content that harms oneself and/or others (e.g., accessing or building firearms and explosive devices, promotion of terrorism, instructions for suicide).",
+        "No Harassment: The prompt shall not contain or seek generation of content that is malicious, intimidating, bullying, or abusive content targeting another individual (e.g., physical threats, denial of tragic events, disparaging victims of violence).",
+        "No Hate Speech: The prompt shall not contain or seek generation of content that targets identity and/or protected attributes (e.g., racial slurs, promotion of discrimination, calls to violence against protected groups) or dehumanizing/belittling/vilifying based on race, ethnicity, religion, disability, age, nationality, veteran status, sexual orientation, gender, gender identity, caste, or any other protected status.",
+        "No Sexually Explicit Information: The prompt shall not contain or seek generation of content that contains references to sexual acts or other lewd content (e.g., sexually graphic descriptions, content aimed at causing arousal). Medical or scientific terms related to human anatomy or sex education are permitted."
+    ]
+
+    for policy in policies:
+        if verificar_politica(message, policy):
+            violation_message = f"Your question violates the following policy: {policy.split(':')[0]}"
+            return violation_message
+
+    assistant_prompt = f"{system_message}\nUser: {message}\nAssistant:"
+    inputs = tokenizer(assistant_prompt, return_tensors="pt")
+    outputs = model.generate(
+        **inputs,
+        max_new_tokens=max_tokens,
+        temperature=temperature,
+        top_p=top_p,
+        do_sample=True,
+    )
+    assistant_response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+    assistant_reply = assistant_response.split("Assistant:")[-1].strip()
+    return assistant_reply
 
-# Build the Gradio interface using Blocks
 with gr.Blocks() as demo:
     gr.Markdown("# Child-Safe-Chatbot (Experimental)")
     gr.Markdown("""
@@ -118,7 +104,7 @@ with gr.Blocks() as demo:
 
     The chatbot will inform the user if any violation occurs and, if not, will proceed to respond to the user's message in a friendly manner.
 
-    Follow me on [Twitter](https://twitter.com/norbertomartnaf) and [GitHub](https://github.com/nmarafo) for more updates and projects!
+    I'm Norberto Martín Afonso. Follow me on [Twitter](https://twitter.com/norbertomartnaf) and [GitHub](https://github.com/nmarafo) for more updates and projects!
     """)
 
     with gr.Accordion("Advanced", open=False):
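Both sides of this diff score a prompt the same way: a single forward pass, then a softmax over the logits the model assigns to the literal tokens 'Yes' and 'No' at the last position. The commit only changes how the prompt is built (one policy at a time instead of all four at once). A minimal standalone sketch of that scoring step, assuming a ShieldGemma-style checkpoint whose vocabulary contains bare 'Yes'/'No' tokens; the 4-bit quantization and HF_TOKEN auth used by the Space are omitted here:

```python
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer

# Illustrative checkpoint; the commit itself moves the app from shieldgemma-2b to shieldgemma-9b.
model_id = "google/shieldgemma-2b"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

def yes_probability(prompt: str) -> float:
    """P('Yes') renormalized against P('No') for the token that would follow `prompt`."""
    inputs = tokenizer(prompt, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits  # shape: (batch, seq_len, vocab_size)
    vocab = tokenizer.get_vocab()
    # Restrict the softmax to the two candidate answer tokens at the final position.
    selected = logits[0, -1, [vocab["Yes"], vocab["No"]]]
    return F.softmax(selected, dim=0)[0].item()

# The refactored respond() flags a violation for a given policy exactly when the
# renormalized 'Yes' probability exceeds the 'No' probability, i.e. yes_probability(...) > 0.5.
```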
 
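The first hunk cuts off the quantization config after its first two fields, as diffs do. For readers unfamiliar with 4-bit loading, a complete config typically has this shape; the last two fields below are common choices, labeled as assumptions, not necessarily the values in this app:

```python
import torch
from transformers import BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,                      # from the diff: store weights in 4-bit
    bnb_4bit_use_double_quant=True,         # from the diff: also quantize the quantization constants
    bnb_4bit_quant_type="nf4",              # assumption: NF4 is the usual quantization type
    bnb_4bit_compute_dtype=torch.bfloat16,  # assumption: dtype used for matmuls at compute time
)
```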