Rubén Escobedo committed on
Commit 55351f1
1 Parent(s): 41c4682

Update app.py

Files changed (1)
  1. app.py +22 -72
app.py CHANGED
@@ -3,51 +3,7 @@ import gradio as gr
 import torchvision.transforms as transforms
 import torch
 
-# Define everything needed to run inference
-from albumentations import (
-    Compose,
-    OneOf,
-    ElasticTransform,
-    GridDistortion,
-    OpticalDistortion,
-    HorizontalFlip,
-    Transpose,
-    CLAHE,
-    ShiftScaleRotate
-)
-
-class SegmentationAlbumentationsTransform(ItemTransform):
-    split_idx = 0
-
-    def __init__(self, aug):
-        self.aug = aug
-
-    def encodes(self, x):
-        img, mask = x
-        aug = self.aug(image=np.array(img), mask=np.array(mask))
-        return PILImage.create(aug["image"]), PILMask.create(aug["mask"])
-
-class TargetMaskConvertTransform(ItemTransform):
-    def __init__(self):
-        pass
-    def encodes(self, x):
-        img, mask = x
-
-        # Convert to array
-        mask = np.array(mask)
-
-        # background = 0, leaves = 1, pole = 74 or 76, wood = 25 or 29, grape = 255
-        mask[mask == 255] = 1 # grape
-        mask[mask == 150] = 2 # leaves
-        mask[mask == 76] = 3 ; mask[mask == 74] = 3 # pole
-        mask[mask == 29] = 4 ; mask[mask == 25] = 4 # wood
-        mask[mask >= 5] = 0 # rest: background
-
-        # Back to PILMask
-        mask = PILMask.create(mask)
-        return img, mask
-
-def transform_image(image, device):
+def transform_image(image):
     my_transforms = transforms.Compose([transforms.ToTensor(),
                                         transforms.Normalize(
                                             [0.485, 0.456, 0.406],
@@ -55,35 +11,29 @@ def transform_image(image, device):
     image_aux = image
     return my_transforms(image_aux).unsqueeze(0).to(device)
 
-def mask_to_img(mask):
-    mask[mask == 1] = 255 # grape
-    mask[mask == 2] = 150 # leaves
-    mask[mask == 3] = 74 # pole
-    mask[mask == 4] = 25 # wood
-    mask = np.reshape(mask, (480, 640))
-
-    return mask
-
 # Define the function that carries out the predictions
 def predict(img):
-    learn = load_learner('export.pkl')
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    model = learn.cpu()
-    model.eval()
-
-    image = transforms.Resize((480,640))(img)
-    tensor = transform_image(image, device)
-
-    model.to(device)
-    with torch.no_grad():
-        outputs = model(tensor)
-
-    outputs = torch.argmax(outputs, 1)
-
-    mask = np.array(outputs.cpu())
-    mask = mask_to_img(mask)
-
-    return Image.fromarray(mask.astype('uint8'))
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    model = torch.jit.load("model.pth")
+    model = model.cpu()
+    model.eval()
+
+    image = transforms.Resize((480,640))(img)
+    tensor = transform_image(image=image)
+
+    model.to(device)
+    with torch.no_grad():
+        outputs = model(tensor)
+
+    mask = np.array(outputs.cpu())
+    mask[mask == 1] = 255 # grape
+    mask[mask == 2] = 150 # leaves
+    mask[mask == 3] = 76 # pole
+    mask[mask == 4] = 29 # wood
+
+    mask = np.reshape(mask, (480, 640))
+
+    return Image.fromarray(mask.astype('uint8'))
 
 # Create the interface and launch it.
 gr.Interface(fn=predict, inputs=gr.inputs.Image(shape=(128, 128)), outputs=gr.outputs.Image(), examples=['color_154.jpg', 'color_155.jpg']).launch(share=False)
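
The rewritten predict() loads a TorchScript module from model.pth instead of unpickling the fastai learner with load_learner('export.pkl'). The commit does not show how model.pth was produced; a minimal sketch of one way to export it from the previous export.pkl (an assumption, not part of this change) would be:

import torch
from fastai.vision.all import load_learner

# The custom ItemTransform classes removed above must still be importable,
# otherwise unpickling export.pkl fails.
learn = load_learner('export.pkl')
model = learn.model.cpu().eval()

# Trace the underlying nn.Module with a dummy input of the shape predict() feeds it
# (one normalized 3x480x640 image), then save the result as TorchScript.
example = torch.rand(1, 3, 480, 640)
scripted = torch.jit.trace(model, example)
scripted.save("model.pth")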
 
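Taken together, the updated app.py reduces to the inference path sketched below. This is a sketch rather than the exact committed file: it adds the imports the changed lines rely on (numpy, PIL, gradio), passes device into transform_image explicitly because the new signature drops that parameter, and keeps the argmax over the class dimension from the previous version on the assumption that model.pth returns raw per-class logits; if the scripted model already returns a class-index map, that line can be dropped.

import numpy as np
import gradio as gr
import torchvision.transforms as transforms
import torch
from PIL import Image

def transform_image(image, device):
    # ImageNet normalization, add a batch dimension, move to the target device
    my_transforms = transforms.Compose([transforms.ToTensor(),
                                        transforms.Normalize(
                                            [0.485, 0.456, 0.406],
                                            [0.229, 0.224, 0.225])])
    return my_transforms(image).unsqueeze(0).to(device)

def predict(img):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = torch.jit.load("model.pth", map_location=device)
    model.eval()

    # Resize expects a PIL image; convert in case Gradio hands over a numpy array
    if isinstance(img, np.ndarray):
        img = Image.fromarray(img)
    image = transforms.Resize((480, 640))(img)
    tensor = transform_image(image, device)

    with torch.no_grad():
        outputs = model(tensor)

    # Assumed: outputs are per-class logits, so take the class index per pixel,
    # then map class ids back to the grayscale values used in the dataset masks
    mask = torch.argmax(outputs, 1).cpu().numpy()
    mask[mask == 1] = 255  # grape
    mask[mask == 2] = 150  # leaves
    mask[mask == 3] = 76   # pole
    mask[mask == 4] = 29   # wood

    mask = np.reshape(mask, (480, 640))
    return Image.fromarray(mask.astype('uint8'))

gr.Interface(fn=predict,
             inputs=gr.inputs.Image(shape=(128, 128)),
             outputs=gr.outputs.Image(),
             examples=['color_154.jpg', 'color_155.jpg']).launch(share=False)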