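"""Child-Safe-Chatbot (experimental) Hugging Face Space.

Screens each user message against four safety policies using
google/shieldgemma-9b and, if no policy is violated, answers the
message with the same model through a simple Gradio chat UI.
"""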
import gradio as gr
import os
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch
import torch.nn.functional as F
token = os.environ.get("HF_TOKEN")
model_id = "google/shieldgemma-9b"
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    quantization_config=bnb_config,
    token=token,
)
system_message = "You are a friendly Chatbot."
def generar_prompt(user_prompt, policy):
    prompt = f"""You are a policy expert trying to help determine whether a user prompt is in violation of the following safety policy:

{policy}

<start_of_turn>
Human Question: {user_prompt.strip()}
<end_of_turn>

Does the human question violate the above policy? Your answer must start with 'Yes' or 'No'. And then walk through step by step to be sure we answer correctly.
"""
    return prompt
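
# Single forward pass: compare the model's probability of answering 'Yes'
# (policy violated) versus 'No' at the first generated position.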
def verificar_politica(message, policy):
    prompt = generar_prompt(message, policy)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits
    vocab = tokenizer.get_vocab()
    yes_token_id = vocab.get('Yes')
    no_token_id = vocab.get('No')
    if yes_token_id is None or no_token_id is None:
        raise ValueError("The tokens 'Yes' or 'No' were not found in the vocabulary.")
    # Softmax over just the 'Yes'/'No' logits at the last position.
    selected_logits = logits[0, -1, [yes_token_id, no_token_id]]
    probabilities = F.softmax(selected_logits, dim=0)
    yes_probability = probabilities[0].item()
    no_probability = probabilities[1].item()
    print(f"Policy: {policy}")
    print(f"Yes probability: {yes_probability}")
    print(f"No probability: {no_probability}")
    return yes_probability > no_probability
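
# Gate-then-generate: reject the message if it violates any policy, otherwise
# produce a normal chat reply with the same model. The @spaces.GPU decorator
# allocates GPU time on Hugging Face Spaces (up to 150 s per call).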
@spaces.GPU(duration=150)
def respond(message, max_tokens, temperature, top_p):
    policies = [
        "No Dangerous Content: The prompt shall not contain or seek generation of content that harms oneself and/or others (e.g., accessing or building firearms and explosive devices, promotion of terrorism, instructions for suicide).",
        "No Harassment: The prompt shall not contain or seek generation of content that is malicious, intimidating, bullying, or abusive content targeting another individual (e.g., physical threats, denial of tragic events, disparaging victims of violence).",
        "No Hate Speech: The prompt shall not contain or seek generation of content that targets identity and/or protected attributes (e.g., racial slurs, promotion of discrimination, calls to violence against protected groups) or dehumanizing/belittling/vilifying based on race, ethnicity, religion, disability, age, nationality, veteran status, sexual orientation, gender, gender identity, caste, or any other protected status.",
        "No Sexually Explicit Information: The prompt shall not contain or seek generation of content that contains references to sexual acts or other lewd content (e.g., sexually graphic descriptions, content aimed at causing arousal). Medical or scientific terms related to human anatomy or sex education are permitted.",
    ]
    # Refuse as soon as any policy is violated.
    for policy in policies:
        if verificar_politica(message, policy):
            violation_message = f"Your question violates the following policy: {policy.split(':')[0]}"
            return violation_message

    # No violation: answer as a plain chatbot with the same model.
    assistant_prompt = f"{system_message}\nUser: {message}\nAssistant:"
    inputs = tokenizer(assistant_prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
    )
    assistant_response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    assistant_reply = assistant_response.split("Assistant:")[-1].strip()
    return assistant_reply
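
# Gradio UI: a chat window plus an "Advanced" accordion for sampling settings.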
with gr.Blocks() as demo:
    gr.Markdown("# Child-Safe-Chatbot (Experimental)")
    gr.Markdown("""
### Description
This chatbot is designed to assist users while ensuring that all interactions comply with defined safety policies. It checks whether user inputs violate any of the following categories:

- Dangerous Content
- Harassment
- Hate Speech
- Sexually Explicit Information

The chatbot will inform the user if any violation occurs and, if not, will proceed to respond to the user's message in a friendly manner.

I'm Norberto Martín Afonso. Follow me on [Twitter](https://twitter.com/norbertomartnaf) and [GitHub](https://github.com/nmarafo) for more updates and projects!
""")
    with gr.Accordion("Advanced", open=False):
        max_tokens = gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens")
        temperature = gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature")
        top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)")

    chatbot = gr.Chatbot()
    message = gr.Textbox(label="Your message")
    submit_button = gr.Button("Send")
    def submit_message(user_message, chat_history, max_tokens, temperature, top_p):
        chat_history = chat_history + [[user_message, None]]
        assistant_reply = respond(user_message, max_tokens, temperature, top_p)
        chat_history[-1][1] = assistant_reply
        # The empty string clears the textbox; the updated history refreshes the chat.
        return "", chat_history

    submit_button.click(
        submit_message,
        inputs=[message, chatbot, max_tokens, temperature, top_p],
        outputs=[message, chatbot],
    )
    message.submit(
        submit_message,
        inputs=[message, chatbot, max_tokens, temperature, top_p],
        outputs=[message, chatbot],
    )

demo.launch(debug=True)