"""Image + title classifier: model architecture, I/O schema, and training augmentation."""
import json
import urllib.request

import requests
import torch
import torch.nn as nn
from PIL import Image
from torchvision import models, transforms
from transformers import AutoModel, AutoTokenizer

# Declared request/response fields for the `inference` entry point
# (presumably consumed by the serving wrapper -- confirm against the caller).
schema = {
    "inputs": [
        {
            "name": "image",
            "type": "image",
            "description": "The image file to classify.",
        },
        {
            "name": "title",
            "type": "string",
            "description": "The text title associated with the image.",
        },
    ],
    "outputs": [
        {
            "name": "label",
            "type": "string",
            "description": "Predicted class label.",
        },
        {
            "name": "probability",
            "type": "float",
            "description": "Prediction confidence score.",
        },
    ],
}


# --- Define the Model ---
class FineGrainedClassifier(nn.Module):
    """Late-fusion classifier over ResNet-50 image features and text-encoder features.

    The 2048-d pooled ResNet-50 embedding and the 768-d first-token hidden
    state of the text encoder are concatenated and fed to an MLP head
    (1024 -> 512 -> num_classes, each hidden layer with BatchNorm, ReLU and
    Dropout(0.3)) that produces one logit per class.

    Args:
        num_classes: number of output classes (default 434).
    """

    def __init__(self, num_classes=434):
        super().__init__()
        # Replacing `fc` with Identity makes the backbone return its 2048-d
        # pooled features instead of ImageNet logits.
        # NOTE(review): `pretrained=True` is deprecated since torchvision 0.13 in
        # favour of `weights=models.ResNet50_Weights.DEFAULT`; kept so older
        # torchvision releases keep working.
        self.image_encoder = models.resnet50(pretrained=True)
        self.image_encoder.fc = nn.Identity()
        # NOTE(review): jina-embeddings-v2 ships custom modelling code that is
        # only used with `trust_remote_code=True`; without it transformers may
        # load a plain BERT class instead. Confirm which variant the checkpoint
        # was trained with before changing this call.
        self.text_encoder = AutoModel.from_pretrained('jinaai/jina-embeddings-v2-base-en')
        self.classifier = nn.Sequential(
            nn.Linear(2048 + 768, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(1024, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes),
        )

    def forward(self, image, input_ids, attention_mask):
        """Return class logits of shape (batch, num_classes).

        Args:
            image: normalised image batch; presumably (batch, 3, 224, 224) to
                match the preprocessing below -- confirm for other callers.
            input_ids: token ids from the matching tokenizer.
            attention_mask: attention mask from the matching tokenizer.
        """
        image_features = self.image_encoder(image)
        text_output = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask)
        # The first token's hidden state serves as the title embedding.
        text_features = text_output.last_hidden_state[:, 0, :]
        combined_features = torch.cat((image_features, text_features), dim=1)
        return self.classifier(combined_features)


# --- Data Augmentation Setup ---
# Training-time pipeline. RandomHorizontalFlip / RandomRotation / ColorJitter
# are random, so this is NOT suitable for deterministic inference.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# --- Remote artefacts and inference settings ---
LABEL_MAP_URL = "https://huggingface.co/Maverick98/EcommerceClassifier/resolve/main/label_to_class.json"
LABEL_MAP_PATH = "label_to_class.json"
CHECKPOINT_URL = "https://huggingface.co/Maverick98/EcommerceClassifier/resolve/main/model_checkpoint.pth"
TEXT_MODEL_NAME = "jinaai/jina-embeddings-v2-base-en"
DEFAULT_THRESHOLD = 0.4
TOP_K = 3
MAX_TITLE_TOKENS = 200
MIN_TITLE_WORDS = 3


def _load_label_map(path=LABEL_MAP_PATH, url=LABEL_MAP_URL):
    """Return the {str(class_index): class_name} mapping.

    Reads the local JSON file at *path* when present, otherwise downloads it
    from *url*. Raises requests.HTTPError if the download fails.
    """
    try:
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)
    except FileNotFoundError:
        response = requests.get(url, timeout=30)
        response.raise_for_status()
        return response.json()


def _load_model(num_classes, checkpoint_url=CHECKPOINT_URL):
    """Build FineGrainedClassifier and load the fine-tuned weights from *checkpoint_url*."""
    model = FineGrainedClassifier(num_classes=num_classes)
    # SECURITY NOTE: load_state_dict_from_url deserialises the download with
    # torch.load (pickle-based); only point this at a trusted source.
    checkpoint = torch.hub.load_state_dict_from_url(checkpoint_url, map_location=torch.device('cpu'))
    # Accept both a training dict holding 'model_state_dict' and a bare state dict.
    state_dict = checkpoint.get('model_state_dict', checkpoint)
    model_keys = set(model.state_dict())  # built once instead of once per key
    cleaned = {}
    for key, value in state_dict.items():
        # Strip the "module." prefix (typically added by DataParallel/DDP wrappers).
        if key.startswith("module."):
            key = key[len("module."):]
        # Keep only entries the current architecture actually has.
        if key in model_keys:
            cleaned[key] = value
    model.load_state_dict(cleaned)
    return model


# Loaded once at import; `inference` previously re-read the JSON on every call,
# and the module-level model construction referenced an undefined name.
label_to_class = _load_label_map()
model = _load_model(len(label_to_class))
tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL_NAME)

# Deterministic inference preprocessing: same resize/normalisation as the
# training `transform`, minus the random flip/rotation/colour jitter that would
# make predictions vary between identical calls.
inference_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def _prepare_image(image):
    """Turn a PIL image or tensor into a batched tensor for the model."""
    if isinstance(image, Image.Image):
        return inference_transform(image.convert('RGB')).unsqueeze(0)
    # NOTE(review): tensors are assumed to be already resized and normalised;
    # a single (C, H, W) image gets a batch dimension added.
    return image.unsqueeze(0) if image.dim() == 3 else image


def inference(inputs, threshold=DEFAULT_THRESHOLD):
    """Classify an item from its image and title.

    Args:
        inputs: dict with "image" (PIL.Image.Image or torch.Tensor) and
            "title" (str of at least MIN_TITLE_WORDS words).
        threshold: when the top class probability is below this value, an
            extra "Others" entry with probability 1 - top probability is added
            first.

    Returns:
        dict mapping up to TOP_K class names (optionally preceded by "Others")
        to probabilities, or {"error": message} if validation fails.
        NOTE(review): this differs from the "label"/"probability" fields in
        `schema["outputs"]`; confirm which shape the caller expects.
    """
    image = inputs.get("image")
    title = inputs.get("title")

    if not isinstance(title, str):
        return {"error": "Title must be a string."}
    if not isinstance(image, (Image.Image, torch.Tensor)):
        return {"error": "Image must be a valid image file or a tensor."}
    # gradio is not imported in this module, so the former `raise gr.Error(...)`
    # was itself a NameError; report it like the other validation failures.
    if len(title.split()) < MIN_TITLE_WORDS:
        return {"error": "Title must be at least 3 words long. Please provide a valid title."}

    pixel_values = _prepare_image(image)
    encoding = tokenizer(title, padding='max_length', max_length=MAX_TITLE_TOKENS,
                         truncation=True, return_tensors='pt')

    # eval() is required: BatchNorm1d in the head rejects batch size 1 in train mode.
    model.eval()
    with torch.no_grad():
        logits = model(pixel_values, input_ids=encoding['input_ids'],
                       attention_mask=encoding['attention_mask'])
    probabilities = torch.nn.functional.softmax(logits, dim=1)
    k = min(TOP_K, probabilities.size(1))  # topk fails if k > number of classes
    top_probs, top_indices = torch.topk(probabilities, k, dim=1)

    results = {}
    best = top_probs[0, 0].item()
    if best < threshold:
        results["Others"] = 1.0 - best
    for prob, idx in zip(top_probs[0].tolist(), top_indices[0].tolist()):
        results[label_to_class[str(idx)]] = prob
    return results