# Source: Hugging Face Space file app.py - commit 40c4a66 ("Update app.py", verified) by kingabzpro; 1.76 kB.
import gradio as gr
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    pipeline,
)
# Hub id of the fine-tuned checkpoint; it is also interpolated into the UI
# description text.
MODEL_URL = "kingabzpro/Llama-3.1-8B-Instruct-Mental-Health-Classification"

tokenizer = AutoTokenizer.from_pretrained(MODEL_URL)

# The app runs this model through a "text-generation" pipeline and reads the
# label out of the generated text, so the checkpoint must be loaded with its
# language-modelling head. AutoModelForSequenceClassification would attach a
# classification head instead, which cannot generate label text.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_URL,
    low_cpu_mem_usage=True,
    return_dict=True,
    torch_dtype=torch.float16,
    device_map="cpu",
)
# Instruction prompt sent to the model; generation continues after the
# trailing "label:" marker.
# NOTE(review): the wording follows the fine-tuning tutorial linked in the UI
# article - confirm it matches the prompt format used during training.
_PROMPT_TEMPLATE = (
    "Classify the text into Normal, Depression, Anxiety, Bipolar, and return "
    "the answer as the corresponding mental health disorder label.\n"
    "text: {text}\n"
    "label:"
)

# Build the pipeline once at import time instead of on every request. The
# model object is already loaded, so dtype/device arguments are not repeated.
_classifier = pipeline("text-generation", model=model, tokenizer=tokenizer)


def prediction(news):
    """Classify free text into a mental health disorder label.

    Args:
        news: Text entered by the user.

    Returns:
        The text the model generated after the ``label:`` marker, with
        surrounding whitespace removed. An empty string is returned for
        empty or whitespace-only input, without calling the model.
    """
    if not news or not news.strip():
        return ""

    prompt = _PROMPT_TEMPLATE.format(text=news.strip())
    outputs = _classifier(
        prompt,
        max_new_tokens=2,
        do_sample=True,
        temperature=0.1,
        # Return only the continuation so the label does not have to be
        # recovered by splitting the echoed prompt on "label: ".
        return_full_text=False,
    )
    return outputs[0]["generated_text"].strip()
# Web UI: one multi-line textbox in, one label out, backed by prediction().
gradio_ui = gr.Interface(
    fn=prediction,
    title="Mental Health Disorder Classification",
    description=f"Input the text to generate a Mental Health Disorder.\n For this classification, the {MODEL_URL} model was used.",
    examples=[
        ['trouble sleeping, confused mind, restless heart. All out of tune'],
        ["In the quiet hours, even the shadows seem too heavy to bear."],
        ["Riding a tempest of emotions, where ecstatic highs crash into desolate lows without warning."],
    ],
    # Components come from the top-level gradio namespace: the legacy
    # gr.inputs / gr.outputs modules (and Label's type="auto" argument) are
    # no longer available in current Gradio releases.
    inputs=gr.Textbox(lines=10, label="Write the text here"),
    outputs=gr.Label(num_top_classes=5, label="Mental Health Disorder Category"),
    theme="huggingface",
    article="<p style='text-align: center'>Please read the tutorial to fine-tune the Llama 3.1 model on Mental Health Classification <a href='https://www.datacamp.com/tutorial/fine-tuning-llama-3-1' target='_blank'>https://www.datacamp.com/tutorial/fine-tuning-llama-3-1</a></p>",
)

# Start the web server for the interface.
gradio_ui.launch()