TroglodyteDerivations
commited on
Commit
•
8a67a54
1
Parent(s):
ce69b66
Updated lines 27 , 37, and 38 with: error handling -> streamlining meaning feedback if an error is logged obverse model fomenting Pokémon predictions.
Browse files
app.py
CHANGED
@@ -24,15 +24,19 @@ def preprocess_image(img_pil):
|
|
24 |
|
25 |
# Function to predict the class of the image
|
26 |
def predict_classification(img_pil):
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
|
|
|
|
|
|
|
|
36 |
|
37 |
# Function to handle the prediction in the Gradio interface
|
38 |
def pokémon_predict(img_pil):
|
|
|
24 |
|
25 |
# Function to predict the class of the image
|
26 |
def predict_classification(img_pil):
|
27 |
+
try:
|
28 |
+
inputs = preprocess_image(img_pil)
|
29 |
+
pixel_values = inputs['pixel_values'][0] # Extract the pixel_values tensor
|
30 |
+
with torch.no_grad():
|
31 |
+
outputs = model(pixel_values=pixel_values.unsqueeze(0)) # Add batch dimension
|
32 |
+
logits = outputs.logits
|
33 |
+
predicted_class_idx = logits.argmax(-1).item()
|
34 |
+
predicted_class = labels[str(predicted_class_idx)]
|
35 |
+
confidence = torch.nn.functional.softmax(logits, dim=1).numpy()[0][predicted_class_idx]
|
36 |
+
return predicted_class, confidence
|
37 |
+
except Exception as e:
|
38 |
+
return "Error occurred during prediction", 0.0
|
39 |
+
|
40 |
|
41 |
# Function to handle the prediction in the Gradio interface
|
42 |
def pokémon_predict(img_pil):
|