TroglodyteDerivations commited on
Commit
8a67a54
1 Parent(s): ce69b66

Updated lines 27 , 37, and 38 with: error handling -> streamlining meaning feedback if an error is logged obverse model fomenting Pokémon predictions.

Browse files
Files changed (1) hide show
  1. app.py +13 -9
app.py CHANGED
@@ -24,15 +24,19 @@ def preprocess_image(img_pil):
24
 
25
  # Function to predict the class of the image
26
  def predict_classification(img_pil):
27
- inputs = preprocess_image(img_pil)
28
- pixel_values = inputs['pixel_values'][0] # Extract the pixel_values tensor
29
- with torch.no_grad():
30
- outputs = model(pixel_values=pixel_values.unsqueeze(0)) # Add batch dimension
31
- logits = outputs.logits
32
- predicted_class_idx = logits.argmax(-1).item()
33
- predicted_class = labels[str(predicted_class_idx)]
34
- confidence = torch.nn.functional.softmax(logits, dim=1).numpy()[0][predicted_class_idx]
35
- return predicted_class, confidence
 
 
 
 
36
 
37
  # Function to handle the prediction in the Gradio interface
38
  def pokémon_predict(img_pil):
 
24
 
25
  # Function to predict the class of the image
26
  def predict_classification(img_pil):
27
+ try:
28
+ inputs = preprocess_image(img_pil)
29
+ pixel_values = inputs['pixel_values'][0] # Extract the pixel_values tensor
30
+ with torch.no_grad():
31
+ outputs = model(pixel_values=pixel_values.unsqueeze(0)) # Add batch dimension
32
+ logits = outputs.logits
33
+ predicted_class_idx = logits.argmax(-1).item()
34
+ predicted_class = labels[str(predicted_class_idx)]
35
+ confidence = torch.nn.functional.softmax(logits, dim=1).numpy()[0][predicted_class_idx]
36
+ return predicted_class, confidence
37
+ except Exception as e:
38
+ return "Error occurred during prediction", 0.0
39
+
40
 
41
  # Function to handle the prediction in the Gradio interface
42
  def pokémon_predict(img_pil):