File size: 5,967 Bytes
99158db
7f034b6
29c2698
ad68c51
280953c
7809349
3ae76eb
7f034b6
ad41af2
e436b6c
ad41af2
e94a870
266545e
 
 
 
e94a870
 
ad68c51
ad41af2
 
e94a870
 
 
266545e
ad41af2
7f034b6
7809349
 
e436b6c
 
 
7809349
 
 
e436b6c
7809349
7f034b6
ad5faaa
e436b6c
 
e0a6db1
7809349
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b4e5bee
 
7809349
e436b6c
b4e5bee
 
 
e436b6c
 
9db9902
e436b6c
 
 
 
 
 
 
 
 
 
 
 
 
 
3c99aa7
e436b6c
 
 
 
 
 
 
 
 
 
 
ad68c51
ec15066
744c734
 
 
 
 
 
 
 
 
 
 
e436b6c
744c734
b4e5bee
7809349
b4e5bee
 
 
 
ad68c51
 
 
 
7809349
b4e5bee
7809349
 
 
b4e5bee
77a6d05
ad68c51
 
 
7809349
ad68c51
 
 
 
7809349
ad68c51
 
7f034b6
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import gradio as gr
import os
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch
import torch.nn.functional as F

# Hugging Face access token read from the environment (None when HF_TOKEN is
# unset); passed to both from_pretrained calls so the model can be downloaded.
token = os.environ.get("HF_TOKEN")

# The same checkpoint is used below both as the policy classifier
# (verificar_politica) and as the reply generator (respond).
model_id = "google/shieldgemma-9b"

# Load the weights with 4-bit NF4 quantization, double quantization enabled,
# and bfloat16 as the compute dtype for the quantized matmuls.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
# device_map="auto" delegates device placement to transformers/accelerate.
# torch_dtype presumably applies to the modules bitsandbytes leaves
# unquantized — TODO confirm against the transformers quantization docs.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    quantization_config=bnb_config,
    token=token,
)

# Persona line prepended to the plain-text chat prompt built in respond().
system_message = "You are a friendly Chatbot."

def generar_prompt(user_prompt, policy):
    """Build the ShieldGemma classification prompt for a single policy.

    The layout follows the prompt-only template published with ShieldGemma:
    the user turn comes first, wrapped in ``<start_of_turn>`` /
    ``<end_of_turn>``; the policy follows as a bullet; then the Yes/No
    question. verificar_politica scores the next-token logits for "Yes" and
    "No" right after this text, so matching the documented template matters
    for classification accuracy.

    Args:
        user_prompt: Raw user text to classify. Surrounding whitespace is
            stripped.
        policy: One safety-policy description. Surrounding whitespace is
            stripped. A leading "* " bullet is added unless the text already
            starts with "*".

    Returns:
        The complete prompt string, ending with a newline.
    """
    policy_text = policy.strip()
    # The template lists policies as markdown bullets; avoid double bullets
    # when the caller already supplied one.
    if not policy_text.startswith("*"):
        policy_text = f"* {policy_text}"
    return f"""You are a policy expert trying to help determine whether a user prompt is in violation of the defined safety policies.

<start_of_turn>
Human Question: {user_prompt.strip()}
<end_of_turn>

Our safety principle is defined in the below:

{policy_text}

Does the human question violate the above principle? Your answer must start with 'Yes' or 'No'. And then walk through step by step to be sure we answer correctly.
"""

def verificar_politica(message, policy):
    """Return True when the classifier judges *message* to violate *policy*.

    One forward pass is run over the classification prompt. From the last
    position's logits only the "Yes" and "No" entries are kept, renormalised
    with a two-way softmax, and compared. Both probabilities are printed for
    debugging.

    Args:
        message: User text to classify.
        policy: Policy description inserted into the prompt.

    Returns:
        ``True`` if P("Yes") > P("No"), otherwise ``False``.

    Raises:
        ValueError: If "Yes" or "No" is not a single entry in the tokenizer
            vocabulary.
    """
    encoded = tokenizer(generar_prompt(message, policy), return_tensors="pt").to(model.device)

    with torch.no_grad():
        all_logits = model(**encoded).logits

    token_table = tokenizer.get_vocab()
    yes_id = token_table.get('Yes')
    no_id = token_table.get('No')
    if yes_id is None or no_id is None:
        raise ValueError("Los tokens 'Yes' o 'No' no se encontraron en el vocabulario.")

    # Index -1 is the prediction for the token immediately after the prompt.
    final_step = all_logits[0, -1]
    pair_probs = F.softmax(final_step[[yes_id, no_id]], dim=0)
    yes_probability, no_probability = (p.item() for p in pair_probs)

    for line in (
        f"Policy: {policy}",
        f"Yes probability: {yes_probability}",
        f"No probability: {no_probability}",
    ):
        print(line)

    return yes_probability > no_probability

@spaces.GPU(duration=150)
def respond(message, max_tokens, temperature, top_p):
    """Screen *message* against the safety policies, then answer it.

    Policies are checked in order with ``verificar_politica``. The first
    policy judged violated short-circuits, and a notice naming that policy
    (the text before its first colon) is returned. Otherwise the same model
    generates a sampled reply to a plain "system / User / Assistant" prompt.

    Args:
        message: Raw user text.
        max_tokens: Upper bound on newly generated tokens (coerced to int).
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.

    Returns:
        The violation notice, or the assistant's reply text.
    """
    policies = [
        "No Dangerous Content: The prompt shall not contain or seek generation of content that harms oneself and/or others (e.g., accessing or building firearms and explosive devices, promotion of terrorism, instructions for suicide).",
        "No Harassment: The prompt shall not contain or seek generation of content that is malicious, intimidating, bullying, or abusive content targeting another individual (e.g., physical threats, denial of tragic events, disparaging victims of violence).",
        "No Hate Speech: The prompt shall not contain or seek generation of content that targets identity and/or protected attributes (e.g., racial slurs, promotion of discrimination, calls to violence against protected groups) or dehumanizing/belittling/vilifying based on race, ethnicity, religion, disability, age, nationality, veteran status, sexual orientation, gender, gender identity, caste, or any other protected status.",
        "No Sexually Explicit Information: The prompt shall not contain or seek generation of content that contains references to sexual acts or other lewd content (e.g., sexually graphic descriptions, content aimed at causing arousal). Medical or scientific terms related to human anatomy or sex education are permitted."
    ]

    for policy in policies:
        if verificar_politica(message, policy):
            violation_message = f"Your question violates the following policy: {policy.split(':')[0]}"
            return violation_message

    assistant_prompt = f"{system_message}\nUser: {message}\nAssistant:"
    inputs = tokenizer(assistant_prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs,
        # The value comes from a Gradio slider and may arrive as a float;
        # the generation budget must be an integer token count.
        max_new_tokens=int(max_tokens),
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
    )
    # For a decoder-only model the returned sequence starts with the prompt.
    # Decode only the new tokens. Splitting the full text on "Assistant:"
    # would discard most of the reply whenever the generated text itself
    # contained that marker.
    prompt_length = inputs["input_ids"].shape[-1]
    assistant_reply = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
    # The model may continue the transcript with a fabricated next user turn;
    # keep only the assistant's own answer.
    assistant_reply = assistant_reply.split("\nUser:")[0]
    return assistant_reply.strip()

with gr.Blocks() as demo:
    gr.Markdown("# Child-Safe-Chatbot (Experimental)")
    gr.Markdown("""
    ### Description
    This chatbot is designed to assist users while ensuring that all interactions comply with defined safety policies. It checks whether user inputs violate any of the following categories:
    
    - Dangerous Content
    - Harassment
    - Hate Speech
    - Sexually Explicit Information
    
    The chatbot will inform the user if any violation occurs and, if not, will proceed to respond to the user's message in a friendly manner.
    
    I'm Norberto Martín Afonso. Follow me on [Twitter](https://twitter.com/norbertomartnaf) and [GitHub](https://github.com/nmarafo) for more updates and projects!
    """)

    # Generation settings forwarded to respond(); collapsed by default.
    with gr.Accordion("Advanced", open=False):
        max_tokens = gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens")
        temperature = gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature")
        top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)")

    chatbot = gr.Chatbot()
    message = gr.Textbox(label="Your message")
    submit_button = gr.Button("Send")

    def submit_message(user_message, chat_history, max_tokens, temperature, top_p):
        """Handle one chat turn: run respond() and append the exchange.

        Args:
            user_message: Text from the message box.
            chat_history: Current Chatbot value as [user, assistant] pairs.
                Treated as empty when None.
            max_tokens, temperature, top_p: Values of the "Advanced" sliders.

        Returns:
            A pair: an empty string (clears the textbox) and the updated
            history.
        """
        # Copy so the incoming history is never mutated in place. `or []`
        # covers a Chatbot whose value is still None.
        history = list(chat_history or [])
        if not (user_message or "").strip():
            # Nothing to screen or answer; skip the GPU round-trip entirely.
            return "", history
        assistant_reply = respond(
            user_message, max_tokens, temperature, top_p
        )
        history.append([user_message, assistant_reply])
        return "", history

    # Clicking "Send" and pressing Enter in the textbox run the same handler.
    submit_button.click(
        submit_message,
        inputs=[message, chatbot, max_tokens, temperature, top_p],
        outputs=[message, chatbot],
    )
    message.submit(
        submit_message,
        inputs=[message, chatbot, max_tokens, temperature, top_p],
        outputs=[message, chatbot],
    )

demo.launch(debug=True)