islasher commited on
Commit
8baab4e
verified
1 Parent(s): 539b88e

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +73 -0
app.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ #app.py:
3
+ # from huggingface_hub import from_pretrained_fastai
4
+ import gradio as gr
5
+ from fastcore.xtras import Path
6
+ from fastai.callback.hook import summary
7
+ from fastai.callback.progress import ProgressCallback
8
+ from fastai.callback.schedule import lr_find, fit_flat_cos
9
+ from fastai.data.block import DataBlock
10
+ from fastai.data.external import untar_data, URLs
11
+ from fastai.data.transforms import get_image_files, FuncSplitter, Normalize
12
+ from fastai.layers import Mish
13
+ from fastai.losses import BaseLoss
14
+ from fastai.optimizer import ranger
15
+ from fastai.torch_core import tensor
16
+ from fastai.vision.augment import aug_transforms
17
+ from fastai.vision.core import PILImage, PILMask
18
+ from fastai.vision.data import ImageBlock, MaskBlock, imagenet_stats
19
+ from fastai.vision.learner import unet_learner
20
+ from PIL import Image
21
+ import numpy as np
22
+ from torch import nn
23
+ from torchvision.models.resnet import resnet34
24
+ import torch
25
+ import torch.nn.functional as F
26
+
27
+
28
+ # # repo_id = "YOUR_USERNAME/YOUR_LEARNER_NAME"
29
+ repo_id = "islasher/segm-grapes"
30
+
31
+
32
+ # # Definimos una funci贸n que se encarga de llevar a cabo las predicciones
33
+
34
+
35
+ # # Cargar el modelo y el tokenizador
36
+ learn = load_learner(repo_id)
37
+ #learner = from_pretrained_fastai(repo_id)
38
+
39
+ import torchvision.transforms as transforms
40
+ def transform_image(image):
41
+ my_transforms = transforms.Compose([transforms.ToTensor(),
42
+ transforms.Normalize(
43
+ [0.485, 0.456, 0.406],
44
+ [0.229, 0.224, 0.225])])
45
+ image_aux = image
46
+ return my_transforms(image_aux).unsqueeze(0).to(device)
47
+
48
+
49
+
50
+
51
+
52
+
53
+ # Definimos una funci贸n que se encarga de llevar a cabo las predicciones
54
+ def predict(img):
55
+ image = transforms.Resize((480,640))(img)
56
+ tensor = transform_image(image=image)
57
+ with torch.no_grad():
58
+ outputs = model(tensor)
59
+
60
+ outputs = torch.argmax(outputs,1)
61
+
62
+ mask = np.array(outputs.cpu())
63
+ mask[mask==1]=150
64
+ mask[mask==3]=76 #pole # y no 74
65
+ # mask[mask==5]=74 #pole
66
+ mask[mask==2]=29 #wood # y no 25
67
+ # mask[mask==6]=25 #wood
68
+ mask[mask==4]=255 #grape
69
+ mask=np.reshape(mask,(480,640)) #en modo matriz
70
+ return Image.fromarray(mask.astype('uint8'))
71
+
72
+ # Creamos la interfaz y la lanzamos.
73
+ gr.Interface(fn=predict, inputs=gr.inputs.Image(shape=(128, 128)), outputs=gr.outputs.Image(shape=(480,640)),examples=['color_154.jpg','color_155.jpg']).launch(share=False)