import warnings

import gradio as gr
import numpy as np
import pandas as pd
import torch
from datasets import load_dataset
from PIL import Image
from sklearn.metrics.pairwise import cosine_similarity
from transformers import AutoModelForCausalLM, AutoProcessor

warnings.filterwarnings('ignore')

# Load Florence-2 model and processor
model_name = "microsoft/Florence-2-base"
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch_dtype,
    trust_remote_code=True
).to(device)
processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)

# Florence-2 always expects an image input, so use a blank image
# when only a text embedding is needed.
DUMMY_IMAGE = Image.new('RGB', (224, 224), color='white')

# Load the first 1,000 rows of the CivitAI dataset
print("Loading dataset...")
dataset = load_dataset("thefcraft/civitai-stable-diffusion-337k", split="train[:1000]")
df = pd.DataFrame(dataset)
print("Dataset loaded successfully!")

# Cache prompt embeddings so each unique prompt is encoded only once
text_embedding_cache = {}


def get_image_embedding(image):
    """Encode an image into a single pooled embedding vector."""
    try:
        inputs = processor(
            images=image,
            text="Generate image description",
            return_tensors="pt",
            padding=True
        ).to(device, torch_dtype)

        # Generate decoder_input_ids so the model can run a full
        # encoder-decoder forward pass
        decoder_input_ids = model.generate(
            **inputs,
            max_new_tokens=20,
            min_length=1,
            num_beams=1,
            do_sample=False,
            pad_token_id=processor.tokenizer.pad_token_id,
            return_dict_in_generate=True,
        ).sequences
        inputs['decoder_input_ids'] = decoder_input_ids

        with torch.no_grad():
            outputs = model(**inputs)
            # Mean-pool the hidden states into one vector per image
            image_embeddings = outputs.last_hidden_state.mean(dim=1)
        return image_embeddings.cpu().numpy()
    except Exception as e:
        print(f"Error in get_image_embedding: {e}")
        return None


def get_text_embedding(text):
    """Encode a prompt into an embedding vector, with caching."""
    try:
        if text in text_embedding_cache:
            return text_embedding_cache[text]

        # Pair the text with the dummy image, since Florence-2
        # requires an image input
        inputs = processor(
            images=DUMMY_IMAGE,
            text=text,
            return_tensors="pt",
            padding=True
        ).to(device, torch_dtype)

        # Use max_new_tokens (not max_length) to cap generation
        decoder_input_ids = model.generate(
            **inputs,
            max_new_tokens=20,
            min_length=1,
            num_beams=1,
            do_sample=False,
            pad_token_id=processor.tokenizer.pad_token_id,
            return_dict_in_generate=True,
        ).sequences
        inputs['decoder_input_ids'] = decoder_input_ids

        with torch.no_grad():
            outputs = model(**inputs)
            text_embeddings = outputs.last_hidden_state.mean(dim=1)

        embedding = text_embeddings.cpu().numpy()
        text_embedding_cache[text] = embedding
        return embedding
    except Exception as e:
        print(f"Error in get_text_embedding: {e}")
        return None


def precompute_embeddings():
    """Embed every dataset prompt once, at startup."""
    print("Pre-computing text embeddings...")
    for idx, row in df.iterrows():
        if row['prompt'] not in text_embedding_cache:
            _ = get_text_embedding(row['prompt'])
        if idx % 100 == 0:
            print(f"Processed {idx}/{len(df)} embeddings")
    print("Finished pre-computing embeddings")


def find_similar_images(uploaded_image, top_k=5):
    """Rank dataset prompts by cosine similarity to the uploaded image."""
    query_embedding = get_image_embedding(uploaded_image)
    if query_embedding is None:
        return [], []

    similarities = []
    for idx, row in df.iterrows():
        prompt_embedding = get_text_embedding(row['prompt'])
        if prompt_embedding is not None:
            similarity = cosine_similarity(query_embedding, prompt_embedding)[0][0]
            similarities.append({
                'similarity': similarity,
                'model': row['Model'],
                'prompt': row['prompt']
            })

    sorted_results = sorted(similarities, key=lambda x: x['similarity'], reverse=True)

    # Collect the top-k unique models and prompts
    top_models = []
    top_prompts = []
    seen_models = set()
    seen_prompts = set()
    for result in sorted_results:
        if len(top_models) < top_k and result['model'] not in seen_models:
            top_models.append(result['model'])
            seen_models.add(result['model'])
        if len(top_prompts) < top_k and result['prompt'] not in seen_prompts:
            top_prompts.append(result['prompt'])
            seen_prompts.add(result['prompt'])
        if len(top_models) == top_k and len(top_prompts) == top_k:
            break
    return top_models, top_prompts


def process_image(input_image):
    """Gradio callback: turn an uploaded image into recommendations."""
    if input_image is None:
        return "Please upload an image.", "Please upload an image."
    try:
        if not isinstance(input_image, Image.Image):
            input_image = Image.fromarray(input_image)

        # Resize the image to the size the processor expects
        input_image = input_image.resize((224, 224))

        recommended_models, recommended_prompts = find_similar_images(input_image)
        if not recommended_models or not recommended_prompts:
            return "Error processing image.", "Error processing image."

        models_text = "Recommended Models:\n" + "\n".join(
            f"{i + 1}. {model}" for i, model in enumerate(recommended_models)
        )
        prompts_text = "Recommended Prompts:\n" + "\n".join(
            f"{i + 1}. {prompt}" for i, prompt in enumerate(recommended_prompts)
        )
        return models_text, prompts_text
    except Exception as e:
        print(f"Error in process_image: {e}")
        return "Error processing image.", "Error processing image."


# Pre-compute embeddings when starting the application
try:
    precompute_embeddings()
except Exception as e:
    print(f"Error in precompute_embeddings: {e}")

# Create the Gradio interface
iface = gr.Interface(
    fn=process_image,
    inputs=gr.Image(type="pil", label="Upload AI-generated image"),
    outputs=[
        gr.Textbox(label="Recommended Models", lines=6),
        gr.Textbox(label="Recommended Prompts", lines=6)
    ],
    title="AI Image Model & Prompt Recommender",
    description="Upload an AI-generated image to get recommendations for Stable Diffusion models and prompts.",
    examples=[],
    cache_examples=False
)

# Launch the interface
iface.launch()
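

# --- Optional optimization sketch (an assumption, not used by the app above) ---
# find_similar_images() calls cosine_similarity() once per dataset row. Since
# all prompt embeddings are already cached after precompute_embeddings(), the
# comparison could instead be done in a single vectorized call.
# `rank_prompts_vectorized` is a hypothetical helper illustrating the idea;
# it is defined here for reference only and is not wired into the interface.
def rank_prompts_vectorized(query_embedding):
    # Stack the cached (1, d) prompt embeddings into one (n, d) matrix
    prompts = [p for p in df['prompt'] if p in text_embedding_cache]
    prompt_matrix = np.vstack([text_embedding_cache[p] for p in prompts])
    # One cosine_similarity call yields a score per prompt: shape (1, n)
    scores = cosine_similarity(query_embedding, prompt_matrix)[0]
    # Return (prompt, score) pairs sorted from most to least similar
    order = np.argsort(scores)[::-1]
    return [(prompts[i], float(scores[i])) for i in order]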