bgaspra committed
Commit: 26d55ba
1 Parent(s): 107b2a4

Update app.py
Files changed (1): app.py (+33, -13)
app.py CHANGED
@@ -21,6 +21,9 @@ model = AutoModelForCausalLM.from_pretrained(
 ).to(device)
 processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
 
+# Create a dummy image for text-only processing
+DUMMY_IMAGE = Image.new('RGB', (224, 224), color='white')
+
 # Load CivitAI dataset
 print("Loading dataset...")
 dataset = load_dataset("thefcraft/civitai-stable-diffusion-337k", split="train[:1000]")
@@ -31,18 +34,28 @@ text_embedding_cache = {}
 
 def get_image_embedding(image):
     try:
-        # Process image and add dummy text input
         inputs = processor(
             images=image,
-            text="Describe this image",  # Adding a default text prompt
-            padding=True,
-            return_tensors="pt"
+            text="Generate image description",
+            return_tensors="pt",
+            padding=True
         ).to(device, torch_dtype)
 
+        # Generate decoder_input_ids
+        decoder_input_ids = model.generate(
+            **inputs,
+            max_length=1,
+            min_length=1,
+            num_beams=1,
+            pad_token_id=processor.tokenizer.pad_token_id,
+            return_dict_in_generate=True,
+        ).sequences
+
+        inputs['decoder_input_ids'] = decoder_input_ids
+
         with torch.no_grad():
-            # Get model outputs
             outputs = model(**inputs)
-            # Extract image features from the cross-attention layers
+            # Use the mean of the last hidden state as the embedding
            image_embeddings = outputs.last_hidden_state.mean(dim=1)
        return image_embeddings.cpu().numpy()
    except Exception as e:
@@ -54,22 +67,26 @@ def get_text_embedding(text):
     if text in text_embedding_cache:
         return text_embedding_cache[text]
 
-    # Process text with proper input formatting
+    # Process text with dummy image
     inputs = processor(
+        images=DUMMY_IMAGE,
         text=text,
-        padding=True,
-        return_tensors="pt"
+        return_tensors="pt",
+        padding=True
     ).to(device, torch_dtype)
 
-    # Add required decoder input ids
-    inputs['decoder_input_ids'] = model.generate(
+    # Generate decoder_input_ids
+    decoder_input_ids = model.generate(
         **inputs,
         max_length=1,
+        min_length=1,
+        num_beams=1,
+        pad_token_id=processor.tokenizer.pad_token_id,
         return_dict_in_generate=True,
-        output_hidden_states=True,
-        early_stopping=True
     ).sequences
 
+    inputs['decoder_input_ids'] = decoder_input_ids
+
     with torch.no_grad():
         outputs = model(**inputs)
         text_embeddings = outputs.last_hidden_state.mean(dim=1)
@@ -134,6 +151,9 @@ def process_image(input_image):
     if not isinstance(input_image, Image.Image):
         input_image = Image.fromarray(input_image)
 
+    # Resize image to expected size
+    input_image = input_image.resize((224, 224))
+
     recommended_models, recommended_prompts = find_similar_images(input_image)
 
     if not recommended_models or not recommended_prompts:
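
For context, a minimal sketch of how the two helpers this commit touches could be exercised end to end. It assumes cosine similarity over cached embeddings (scikit-learn here is an assumption; app.py's find_similar_images may compute similarity differently) and a 'prompt' column in the CivitAI dataset:

    # Hypothetical usage sketch -- not part of this commit. Assumes the
    # helpers above plus scikit-learn; names marked 'assumed' are illustrative.
    import numpy as np
    from PIL import Image
    from sklearn.metrics.pairwise import cosine_similarity

    # Embed a query image the same way process_image does (224x224 RGB)
    query = Image.open("query.png").convert("RGB").resize((224, 224))
    query_emb = get_image_embedding(query)  # numpy array, shape (1, hidden_size)

    # Embed a slice of dataset prompts ('prompt' column assumed)
    prompts = [row["prompt"] for row in dataset.select(range(100))]
    prompt_embs = np.vstack([get_text_embedding(p) for p in prompts])

    # Rank prompts by cosine similarity to the query image
    scores = cosine_similarity(query_emb, prompt_embs)[0]
    top5 = [prompts[i] for i in np.argsort(scores)[::-1][:5]]

Both helpers mean-pool last_hidden_state from the same model, so the image and text vectors land in a comparable space; the max_length=1 generate call appears to exist only to supply the decoder_input_ids that the encoder-decoder forward pass requires.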