gabrielchua's picture
Update app.py
26ad4ba verified
raw
history blame
No virus
2.7 kB
import spaces
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import os
import json
import requests
from openai import OpenAI
# Hugging Face access token, read from the environment and passed to both
# `from_pretrained` calls below.
HF_API_KEY = os.getenv("HF_API_KEY")
MODEL_ID = "meta-llama/Llama-Guard-3-1B"

# Load the Llama Guard weights in bfloat16 and place them on the GPU.
# NOTE(review): `device_map="auto"` followed by `.to('cuda')` looks redundant;
# kept as-is to preserve the existing placement behaviour -- confirm whether
# both are needed in the deployment environment.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    token=HF_API_KEY,
).to('cuda')

# Tokenizers are not torch modules and expose no `.to()` method, so the
# previous `.to('cuda')` call raised AttributeError at import time. Only the
# tensors the tokenizer produces need moving to the model's device, which
# `llama_guard_moderation` already does.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_API_KEY)
# Model 1: Llama Guard Model
@spaces.GPU
def llama_guard_moderation(input_text):
    """Classify `input_text` with Llama Guard and return the decoded output.

    The text is wrapped as a single user turn, rendered through the
    tokenizer's chat template, and at most 20 new tokens are generated.
    Only the tokens after the prompt are decoded and returned as a string.
    """
    user_turn = {
        "role": "user",
        "content": [{"type": "text", "text": input_text}],
    }
    prompt_ids = tokenizer.apply_chat_template(
        [user_turn], return_tensors="pt"
    ).to(model.device)

    generated = model.generate(
        prompt_ids,
        max_new_tokens=20,
        pad_token_id=0,
    )

    # Slice off the leading prompt columns so only the continuation remains.
    continuation = generated[:, prompt_ids.shape[1]:]
    return tokenizer.decode(continuation[0])
# Model 2: OpenAI Omni Moderation
def openai_moderation(input_text):
    """Return the moderation categories OpenAI reports for `input_text`.

    A fresh `OpenAI()` client is created on every call (presumably picking
    up credentials from the environment -- confirm), the text is sent to the
    "omni-moderation-latest" model, and the `categories` attribute of the
    first result is returned unchanged.
    """
    moderation = OpenAI().moderations.create(
        input=input_text,
        model="omni-moderation-latest",
    )
    first_result = moderation.results[0]
    return first_result.categories
# Model 3: Sentinel API for LionGuard
def sentinel_moderation(input_text, timeout=30):
    """Return the LionGuard output of the Sentinel API for `input_text`.

    The endpoint and API key are read from the `SENTINEL_ENDPOINT` and
    `SENTINEL_API_KEY` environment variables on every call.

    Args:
        input_text: Text to moderate.
        timeout: Seconds to wait for the HTTP request before giving up.

    Returns:
        The value stored under `["outputs"]["lionguard"]` in the JSON reply.

    Raises:
        RuntimeError: If `SENTINEL_ENDPOINT` is not set.
        requests.HTTPError: If the API answers with a 4xx/5xx status.
        requests.Timeout: If the API does not answer within `timeout`.
    """
    api_key = os.getenv("SENTINEL_API_KEY")
    api_endpoint = os.getenv("SENTINEL_ENDPOINT")
    if not api_endpoint:
        # Fail with a clear message instead of requests' "Invalid URL 'None'".
        raise RuntimeError("SENTINEL_ENDPOINT environment variable is not set")

    headers = {
        "x-api-key": api_key,
        "Content-Type": "application/json",
    }
    payload = {
        "filters": ["lionguard"],
        "text": input_text,
    }
    response = requests.post(
        url=api_endpoint,
        headers=headers,
        json=payload,
        # Without a timeout a stalled server would block the request forever.
        timeout=timeout,
    )
    # Surface HTTP errors directly rather than as a KeyError/JSON error below.
    response.raise_for_status()
    return response.json()["outputs"]["lionguard"]
# Gradio App
def moderate_text(input_text):
    """Run `input_text` through all three moderators.

    Returns a 3-tuple in the order (Llama Guard, OpenAI, Sentinel); the
    models are invoked sequentially in that same order.
    """
    return (
        llama_guard_moderation(input_text),
        openai_moderation(input_text),
        sentinel_moderation(input_text),
    )
# One input textbox feeds `moderate_text`; its three return values are shown
# in three output textboxes, in the same order as the returned tuple.
iface = gr.Interface(
    title="Content Moderation Model Comparison",
    description=(
        "Compare the performance of 3 content moderation models: "
        "Llama Guard, OpenAI Omni Moderation, and Sentinel LionGuard."
    ),
    fn=moderate_text,
    inputs=gr.Textbox(label="Enter Text for Moderation", lines=5),
    outputs=[
        gr.Textbox(label=result_label)
        for result_label in (
            "Llama Guard Result",
            "OpenAI Omni Moderation Result",
            "Sentinel LionGuard Result",
        )
    ],
)

# Start the web UI as soon as the module is executed.
iface.launch()