# Spaces: Sleeping  (web-page header residue captured with the source; not part of the program)
import logging

import gradio as gr
import torch
from transformers import pipeline, AutoModelForSequenceClassification, AutoTokenizer
# Set up logging.
# Configure the root logger at INFO so the progress messages emitted below are
# shown, then create the module-level logger used by the rest of this script.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
logger.info("Starting the script")
# Load model and tokenizer.
# Both are fetched by name via ``from_pretrained``; a failure here is logged
# with its traceback and re-raised, so the script stops instead of continuing
# without a model.
model_name = "peterkros/immunization-classification-model"
try:
    logger.info(f"Loading model from {model_name}")
    model = AutoModelForSequenceClassification.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    logger.info("Model and tokenizer loaded successfully")
except Exception as e:
    # logger.exception records the traceback alongside the message.
    logger.exception("Error loading model and tokenizer: %s", e)
    raise
# Define the pipeline.
# Wrap the already-loaded model and tokenizer in a "text-classification"
# pipeline; ``classifier`` is the callable used by ``classify_text``.
try:
    logger.info("Setting up the pipeline")
    classifier = pipeline("text-classification", model=model, tokenizer=tokenizer)
    logger.info("Pipeline set up successfully")
except Exception as e:
    # Log with traceback, then propagate: the app cannot work without it.
    logger.exception("Error setting up the pipeline: %s", e)
    raise
def classify_text(text, *, threshold=0.92):
    """Classify *text* and relabel each prediction by a confidence threshold.

    The module-level ``classifier`` pipeline is called on ``text``. Every
    prediction whose ``score`` is strictly greater than ``threshold`` is
    reported as ``"Immunization"``; all others are reported as ``"None"``.
    The label produced by the model itself is not consulted.
    NOTE(review): this assumes a high score always refers to the
    immunization class -- confirm against the model's label set.

    Args:
        text: The input to classify; passed to ``classifier`` unchanged.
        threshold: Score a prediction must exceed to be labelled
            ``"Immunization"``. Keyword-only; defaults to 0.92.

    Returns:
        A list of ``{"label": ..., "score": ...}`` dicts, one per prediction,
        or ``{"error": message}`` when the text is missing/blank or when
        classification raises.
    """
    # Blank input carries nothing to classify, so skip the model call and
    # report it in the same error shape used for failures below.
    if text is None or (isinstance(text, str) and not text.strip()):
        logger.warning("No text provided for classification")
        return {"error": "No text provided"}

    try:
        logger.info("Classifying text: %s", text)
        predictions = classifier(text)
        logger.info("Predictions: %s", predictions)

        # Custom logic: the reported label depends only on the score.
        result = [
            {
                "label": "Immunization" if prediction["score"] > threshold else "None",
                "score": prediction["score"],
            }
            for prediction in predictions
        ]
        logger.info("Processed predictions: %s", result)
        return result
    except Exception as e:
        # Failures are returned as data rather than raised, so the caller
        # always receives something it can display; the traceback goes to
        # the log.
        logger.exception("Error classifying text: %s", e)
        return {"error": str(e)}
# Create Gradio interface.
# One two-line text box feeds ``classify_text``; its return value (a list of
# predictions or an error dict) is rendered by a JSON output component.
try:
    logger.info("Setting up Gradio interface")
    iface = gr.Interface(
        fn=classify_text,
        inputs=gr.Textbox(lines=2, placeholder="Enter text here..."),
        outputs=gr.JSON(),
        title="Text Classification with DistilBERT",
        description="Enter text to classify it using a DistilBERT model trained for text classification."
    )
    logger.info("Gradio interface set up successfully")
except Exception as e:
    # Log with traceback, then propagate: there is no app without the UI.
    logger.exception("Error setting up Gradio interface: %s", e)
    raise
# Launch the app.
# Only start the server when this file is executed as a script.
if __name__ == "__main__":
    try:
        logger.info("Launching Gradio interface")
        iface.launch()
        # NOTE(review): launch() is expected to block while the server is
        # running, so this message is probably only logged once it returns
        # -- confirm against the Gradio version in use.
        logger.info("Gradio interface launched successfully")
    except Exception as e:
        # Log with traceback, then propagate so the process exits non-zero.
        logger.exception("Error launching Gradio interface: %s", e)
        raise