|
import functools

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
|
|
|
MODEL_URL = "kingabzpro/Llama-3.1-8B-Instruct-Mental-Health-Classification"

tokenizer = AutoTokenizer.from_pretrained(MODEL_URL)
# Reuse EOS as the padding token so generation has a valid pad id
# (presumably the tokenizer ships without a dedicated pad token — TODO confirm).
tokenizer.pad_token_id = tokenizer.eos_token_id

# The model runs on CPU, so use bfloat16 rather than float16: both halve memory
# relative to float32, but float16 matmuls on CPU are slow and, on some PyTorch
# versions, not implemented at all ("... not implemented for 'Half'").
# bfloat16 has much better CPU kernel support for inference.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_URL,
    low_cpu_mem_usage=True,
    return_dict=True,
    torch_dtype=torch.bfloat16,
    device_map="cpu",
)
|
|
|
# Canonical labels the prompt asks the model to choose between.
_LABELS = ("Normal", "Depression", "Anxiety", "Bipolar")


@functools.lru_cache(maxsize=1)
def _get_pipeline():
    """Return the shared text-generation pipeline, building it on first use.

    The wrapped model and tokenizer never change between requests, so one
    pipeline instance is reused instead of constructing a new one per call.
    """
    return pipeline("text-generation", model=model, tokenizer=tokenizer)


def prediction(text):
    """Classify *text* into a mental health category using the fine-tuned LLM.

    Args:
        text: Free-form user text from the Gradio textbox.

    Returns:
        One of ``_LABELS`` when the model's answer contains it
        (case-insensitive), otherwise the model's raw, stripped answer.

    Raises:
        gr.Error: if *text* is empty or whitespace-only, so the UI shows a
            message instead of classifying an empty prompt.
    """
    if not text or not text.strip():
        raise gr.Error("Please enter some text to classify.")

    # Prompt ends with "label:" (no trailing space) so the model continues
    # with the label itself.
    prompt = (
        "Classify the text into Normal, Depression, Anxiety, Bipolar, and return "
        "the answer as the corresponding mental health disorder label.\n"
        f"text: {text}\n"
        "label:"
    )

    outputs = _get_pipeline()(
        prompt,
        max_new_tokens=2,
        # Greedy decoding: identical input always yields the same label.
        do_sample=False,
        # Return only the newly generated tokens, so parsing never has to
        # search the echoed prompt (which also contains the user's text).
        return_full_text=False,
    )
    answer = outputs[0]["generated_text"].strip()

    # Map e.g. "Depression." or "anxiety" onto the canonical label name.
    lowered = answer.lower()
    for label in _LABELS:
        if label.lower() in lowered:
            return label
    return answer
|
|
|
|
|
gradio_ui = gr.Interface(
    fn=prediction,
    title="Mental Health Disorder Classification",
    description=(
        "Enter some text to classify it as Normal, Depression, Anxiety, or Bipolar.\n"
        f"For this classification, the {MODEL_URL} model was used."
    ),
    examples=[
        ['trouble sleeping, confused mind, restless heart. All out of tune'],
        ["In the quiet hours, even the shadows seem too heavy to bear."],
        ["Riding a tempest of emotions, where ecstatic highs crash into desolate lows without warning."],
    ],
    inputs=gr.Textbox(lines=10, label="Write the text here"),
    outputs=gr.Label(num_top_classes=4, label="Mental Health Disorder Category"),
    theme=gr.themes.Soft(),
    article="<p style='text-align: center'>Please read the tutorial to fine-tune the Llama 3.1 model on Mental Health Classification <a href='https://www.datacamp.com/tutorial/fine-tuning-llama-3-1' target='_blank'>https://www.datacamp.com/tutorial/fine-tuning-llama-3-1</a></p>",
)


# Start the web server only when run as a script (e.g. `python app.py`), so
# importing this module to reuse or mount `gradio_ui` does not launch a server.
if __name__ == "__main__":
    gradio_ui.launch()