TroglodyteDerivations commited on
Commit
9f94e87
1 Parent(s): 19e6265

Updated line 25 with: preprocess_image function -> Validates whether image is a PIL image -> if not then it converts the image utilizing Image.fromarray

Browse files
Files changed (1) hide show
  1. app.py +5 -4
app.py CHANGED
@@ -23,10 +23,11 @@ preprocess = T.Compose([
23
 
24
  # Function to preprocess the image
25
  def preprocess_image(img_pil):
26
- img_pil = img_pil.convert('RGB') # Convert to RGB if necessary
27
- img_tensor = preprocess(img_pil)
28
- img_tensor = img_tensor.unsqueeze(0) # Add batch dimension
29
- return img_tensor
 
30
 
31
  # Function to predict the class of the image
32
  def predict_classification(img_pil):
 
23
 
24
  # Function to preprocess the image
25
  def preprocess_image(img_pil):
26
+ # Ensure img_pil is a PIL Image object
27
+ if not isinstance(img_pil, Image.Image):
28
+ img_pil = Image.fromarray(img_pil) # Convert NumPy array to PIL Image
29
+ inputs = feature_extractor(images=img_pil, return_tensors="pt")
30
+ return inputs
31
 
32
  # Function to predict the class of the image
33
  def predict_classification(img_pil):