File size: 3,529 Bytes
8baab4e
 
 
 
af8ae33
aa505fc
8baab4e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a092642
 
8baab4e
 
 
 
a092642
8baab4e
 
a092642
8baab4e
 
a092642
8baab4e
aa505fc
 
2f428ca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6786150
2f428ca
 
 
 
6786150
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66cbf63
 
 
8baab4e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1b51526
8baab4e
 
 
4c173bc
8baab4e
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130

#app.py:
# from huggingface_hub import from_pretrained_fastai
import gradio as gr

from fastai import *
from fastai.data.block import DataBlock
from fastai.data.transforms import get_image_files, FuncSplitter, Normalize
from fastai.layers import Mish
from fastai.losses import BaseLoss
from fastai.optimizer import ranger
from fastai.torch_core import tensor
from fastai.vision.all import ItemTransform
from fastai.vision.augment import aug_transforms
from fastai.vision.core import PILImage, PILMask
from fastai.vision.data import ImageBlock, MaskBlock, imagenet_stats
from fastai.vision.learner import unet_learner
from PIL import Image
import numpy as np
from torch import nn
import torch
import torch.nn.functional as F


# # repo_id = "YOUR_USERNAME/YOUR_LEARNER_NAME"
# repo_id = "islasher/segm-grapes"
# repo_id='islasher/segm-grapes'


# # Define a function that is responsible for carrying out the predictions

# from fastai.learner import load_learner

# # Load the model and the tokenizer
# learn = load_learner(repo_id)
#learner = from_pretrained_fastai(repo_id)

from huggingface_hub import from_pretrained_fastai
import torchvision.transforms as transforms
# from Transform import ItemTransform


from albumentations import (
    Compose,
    OneOf,
    ElasticTransform,
    GridDistortion,
    OpticalDistortion,
    HorizontalFlip,
    Rotate,
    Transpose,
    CLAHE,
    ShiftScaleRotate
)

class SegmentationAlbumentationsTransform(ItemTransform):
    """Apply one augmentation callable jointly to an (image, mask) pair.

    The wrapped callable is invoked with ``image=`` and ``mask=`` keyword
    arguments (both converted to NumPy arrays) and must return a mapping
    with ``"image"`` and ``"mask"`` entries, which are wrapped back into
    ``PILImage`` / ``PILMask``.

    NOTE(review): this class is not instantiated anywhere in this file; it is
    presumably kept so that the exported learner, which was pickled with a
    reference to it, can be loaded -- confirm before renaming or removing.
    """

    # NOTE(review): in fastai, split_idx = 0 is understood to restrict the
    # transform to the training subset -- confirm against the fastai docs.
    split_idx = 0

    def __init__(self, aug):
        # Keep the attribute name `aug`: pickled instances restore it by name.
        self.aug = aug

    def encodes(self, x):
        """Return the augmented ``(PILImage, PILMask)`` pair for ``x``."""
        img, mask = x
        augmented = self.aug(image=np.array(img), mask=np.array(mask))
        new_img = PILImage.create(augmented["image"])
        new_mask = PILMask.create(augmented["mask"])
        return new_img, new_mask

class TargetMaskConvertTransform(ItemTransform):
    """Convert raw grey-level mask values into consecutive class indices.

    Class order (from the original comment):
    ``['Background', 'Leaves', 'Wood', 'Pole', 'Grape']`` -> 0..4.
    Values not listed in the table (including 0, the background) are left
    unchanged.
    """

    # (raw grey value, class index), applied in this order.
    _RAW_TO_CLASS = (
        (150, 1),  # leaves
        (76, 3),   # pole
        (74, 3),   # pole
        (29, 2),   # wood
        (25, 2),   # wood
        (255, 4),  # grape
    )

    def __init__(self):
        pass

    def encodes(self, x):
        """Return ``(img, mask)`` with the mask remapped to class indices."""
        img, mask = x

        # Work on a NumPy copy so the values can be rewritten in place.
        mask = np.array(mask)
        for raw_value, class_idx in self._RAW_TO_CLASS:
            mask[mask == raw_value] = class_idx

        # Back to PILMask
        return img, PILMask.create(mask)




# Load the fastai learner published under the Hugging Face Hub repo id
# "islasher/segm-grapes". This runs at import time.
# NOTE(review): the two transform classes defined above are presumably needed
# for the pickled learner to load -- confirm before moving this line.
learn = from_pretrained_fastai("islasher/segm-grapes")

def transform_image(image):
    """Turn an image into a normalised, batched tensor on the model's device.

    Args:
        image: input accepted by ``transforms.ToTensor`` (e.g. a PIL image).

    Returns:
        The tensor produced by ``ToTensor`` + ``Normalize`` with a leading
        batch dimension added, moved to the device holding ``learn.model``.
    """
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        # Same mean / std constants the original code used.
        transforms.Normalize(
            [0.485, 0.456, 0.406],
            [0.229, 0.224, 0.225],
        ),
    ])
    # No global `device` exists in this file; follow the model's own device so
    # the input and the weights always live in the same place.
    model_device = next(learn.model.parameters()).device
    return preprocess(image).unsqueeze(0).to(model_device)






# Define the function that is responsible for carrying out the predictions
def predict(img):
    """Segment ``img`` and return the result as a grey-level PIL image.

    Args:
        img: the input picture, either a NumPy array or a PIL image.

    Returns:
        A 480x640 ``PIL.Image`` whose pixel values encode the predicted
        class: 0 stays 0, 1 -> 150, 2 -> 29, 3 -> 76, 4 -> 255.
    """
    # The resize step below is applied to a PIL image, so convert arrays first.
    if isinstance(img, np.ndarray):
        img = Image.fromarray(img)
    # The normalisation uses three channel statistics, so force 3-channel RGB.
    if isinstance(img, Image.Image):
        img = img.convert("RGB")

    resized = transforms.Resize((480, 640))(img)
    batch = transform_image(image=resized)

    # Inference only: eval mode for dropout / batch-norm, no gradients.
    learn.model.eval()
    with torch.no_grad():
        logits = learn.model(batch)

    class_map = torch.argmax(logits, 1)

    # Move to host memory before converting; a CUDA tensor cannot be turned
    # into a NumPy array directly.
    mask = class_map.cpu().numpy()

    # Map class indices back to the grey levels of the original masks.
    # Pole and wood each had two raw values during training (76/74 and 29/25);
    # only 76 and 29 are emitted here.
    for class_idx, grey in ((1, 150), (3, 76), (2, 29), (4, 255)):
        mask[mask == class_idx] = grey

    mask = np.reshape(mask, (480, 640))  # drop the batch axis: 2-D matrix
    return Image.fromarray(mask.astype('uint8'))
    
# Build the Gradio interface and launch it.
# NOTE(review): the input component is configured with shape=(128, 128) while
# predict() resizes to 480x640 -- confirm this downscaling is intended.
gr.Interface(
    fn=predict,
    inputs=gr.inputs.Image(shape=(128, 128)),
    # The legacy output Image component is created without a `shape` argument;
    # predict() already returns a 480x640 image.
    outputs=gr.outputs.Image(),
    examples=['color_154.jpg', 'color_155.jpg'],
).launch(share=False)