youl commited on
Commit
834d6b9
1 Parent(s): a4da9d6

Create utils file

Browse files
Files changed (1) hide show
  1. utils.py +230 -0
utils.py ADDED
@@ -0,0 +1,230 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from google.cloud import vision
3
+ import re
4
+ import torch
5
+ import torchvision
6
+ import numpy as np
7
+ from PIL import Image
8
+ import albumentations as A
9
+ from albumentations.pytorch import ToTensorV2
10
+ import tempfile
11
+ import json
12
+
13
+ def getcredentials():
14
+ secret_key_credential = os.getenv("secret_key")
15
+
16
+ with tempfile.NamedTemporaryFile(mode='w+', delete= True, suffix=".json") as temp_file:
17
+ temp_file.write(json.dumps(secret_key_credential))
18
+
19
+ tempfile_name = temp_file.name
20
+
21
+ return tempfile_name
22
+
23
+
24
+ os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = getcredentials()
25
+
26
+ ##
27
+ def info_new_cni(donnees):
28
+ ##
29
+ informations = {}
30
+
31
+ # Utilisation d'expressions régulières pour extraire les informations spécifiques
32
+ numero_carte = re.search(r'n° (C\d+)', ' '.join(donnees))
33
+ #prenom_nom = re.search(r'Prénom\(s\)\s+(.*?)\s+Nom\s+(.*?)\s+Signature', ' '.join(donnees))
34
+ nom = re.search(r'Nom\s+(.*?)\s', ' '.join(donnees))
35
+ prenom = re.search(r'Prénom\(s\)\s+(.*?)\s+Nom\s+(.*?)', ' '.join(donnees))
36
+ date_naissance = re.search(r'Date de Naissance\s+(.*?)+(\d{2}/\d{2}/\d{4})', ' '.join(donnees))
37
+ lieu_naissance = re.search(r'Lieu de Naissance\s+(.*?)\s', ' '.join(donnees))
38
+ taille = re.search(r'Sexe Taille\s+(.*?)+(\d+,\d+)', ' '.join(donnees))
39
+ nationalite = re.search(r'Nationalité\s+(.*?)\s+\d+', ' '.join(donnees))
40
+ date_expiration = re.search(r'Date d\'expiration\s+(\d+/\d+/\d+)', ' '.join(donnees))
41
+ sexe = re.search(r'Date de Naissance\s+(.*?)+(\d{2}/\d{2}/\d{4})+(.*)', ' '.join(donnees))
42
+
43
+ # Stockage des informations extraites dans un dictionnaire
44
+ if numero_carte:
45
+ informations['Numéro de carte'] = numero_carte.group(1)
46
+ if nom :
47
+ informations['Nom'] = nom.group(1)
48
+
49
+ if prenom:
50
+ informations['Prénom'] = prenom.group(1)
51
+
52
+ if date_naissance:
53
+ informations['Date de Naissance'] = date_naissance.group(2)
54
+ if lieu_naissance:
55
+ informations['Lieu de Naissance'] = lieu_naissance.group(1)
56
+ if taille:
57
+ informations['Taille'] = taille.group(2)
58
+ if nationalite:
59
+ informations['Nationalité'] = nationalite.group(1)
60
+ if date_expiration:
61
+ informations['Date d\'expiration'] = date_expiration.group(1)
62
+ if sexe :
63
+ informations['sexe'] = sexe.group(3)[:2]
64
+
65
+ return informations
66
+
67
+ ##
68
+
69
+ def info_ancien_cni(infos):
70
+ """ Extract information in row data of ocr"""
71
+
72
+ informations = {}
73
+
74
+ immatriculation_patern = r'Immatriculation:\s+(C \d{4} \d{4} \d{2})'
75
+ immatriculation = re.search(immatriculation_patern, ''.join(infos))
76
+ nom = infos[4]
77
+ prenom_pattern = r'Nom\n(.*?)\n'
78
+ prenom = re.search(prenom_pattern, '\n'.join(infos))
79
+ sexe_pattern = r'Prénoms\n(.*?)\n'
80
+ sexe = re.search(sexe_pattern, '\n'.join(infos))
81
+ taille_pattern = r'Sexe\n(.*?)\n'
82
+ taille = re.search(taille_pattern, '\n'.join(infos))
83
+ date_naiss_pattern = r'Taille\s+(.*?)+(\d+/\d+/\d+)' # r'Taille (m)\n(.*?)\n'
84
+ date_naissance = re.search(date_naiss_pattern, ' '.join(infos))
85
+ lieu_pattern = r'Date de Naissance\n(.*?)\n'
86
+ lieu_naissance = re.search(lieu_pattern, '\n'.join(infos))
87
+ valide_pattern = r'Valide jusqu\'au+(.*?)+(\d+/\d+/\d+)'
88
+ validite = re.search(valide_pattern, ' '.join(infos))
89
+
90
+ # Stockage des informations extraites dans un dictionnaire
91
+ if immatriculation:
92
+ informations['Immatriculation'] = immatriculation.group(1)
93
+ if nom :
94
+ informations['Nom'] = infos[4]
95
+
96
+ if prenom:
97
+ informations['Prénom'] = prenom.group(1)
98
+
99
+ if date_naissance:
100
+ informations['Date de Naissance'] = date_naissance.group(2)
101
+ if lieu_naissance:
102
+ informations['Lieu de Naissance'] = lieu_naissance.group(1)
103
+ if taille:
104
+ informations['Taille'] = taille.group(1)
105
+
106
+ if validite:
107
+ informations['Date d\'expiration'] = validite.group(2)
108
+ if sexe :
109
+ informations['sexe'] = sexe.group(1)
110
+
111
+ return informations
112
+
113
+ ##
114
+ def filtrer_elements(liste):
115
+ elements_filtres = []
116
+ for element in liste:
117
+ if element not in ['\r',"RÉPUBLIQUE DE CÔTE D'IVOIRE", "MINISTÈRE DES TRANSPORTS", "PERMIS DE CONDUIRE"]:
118
+ elements_filtres.append(element)
119
+ return elements_filtres
120
+
121
+ def permis_de_conduite(donnees):
122
+ """ Extraire les information de permis de conduire"""
123
+
124
+ informations = {}
125
+
126
+ infos = filtrer_elements(donnees)
127
+
128
+ nom_pattern = r'Nom\n(.*?)\n'
129
+ nom = re.search(nom_pattern, '\n'.join(infos))
130
+ prenom_pattern = r'Prénoms\n(.*?)\n'
131
+ prenom = re.search(prenom_pattern, '\n'.join(infos))
132
+ date_lieu_naissance_patern = r'Date et lieu de naissance\n(.*?)\n'
133
+ date_lieu_naissance = re.search(date_lieu_naissance_patern, '\n'.join(infos))
134
+ date_lieu_delivrance_patern = r'Date et lieu de délivrance\n(.*?)\n'
135
+ date_lieu_delivrance = re.search(date_lieu_delivrance_patern, '\n'.join(infos))
136
+ numero_pattern = r'Numéro du permis de conduire\n(.*?)\n'
137
+ numero = re.search(numero_pattern, '\n'.join(infos))
138
+ restriction_pattern = r'Restriction\(s\)\s+(.*?)+(.*)'
139
+ restriction = re.search(restriction_pattern, ' '.join(infos))
140
+
141
+ # Stockage des informations extraites dans un dictionnaire
142
+ if nom:
143
+ informations['Nom'] = nom.group(1)
144
+
145
+ if prenom :
146
+ informations['Prenoms'] = prenom.group(1)
147
+ if date_lieu_naissance :
148
+ informations['Date_et_lieu_de_naissance'] = date_lieu_naissance.group(1)
149
+ if date_lieu_naissance :
150
+ informations['Date_et_lieu_de_délivrance'] = date_lieu_delivrance.group(1)
151
+
152
+ informations['Categorie'] = infos[0]
153
+ if numero:
154
+ informations['Numéro_du_permis_de_conduire'] = numero.group(1)
155
+
156
+ if restriction:
157
+ informations['Restriction(s)'] = restriction.group(2)
158
+
159
+ return informations
160
+
161
+
162
+ # Fonction pour extraire les informations individuelles
163
+ def extraire_informations_carte(path, type_de_piece=1):
164
+ """ Detect text in identity card"""
165
+
166
+ client = vision.ImageAnnotatorClient()
167
+
168
+ with open(path,'rb') as image_file:
169
+ content = image_file.read()
170
+
171
+ image = vision.Image(content = content)
172
+
173
+ # for non dense text
174
+ #response = client.text_detection(image=image)
175
+ #for dense text
176
+ response = client.document_text_detection(image = image)
177
+ texts = response.text_annotations
178
+ ocr_texts = []
179
+
180
+ for text in texts:
181
+ ocr_texts.append(f"\r\n{text.description}")
182
+
183
+ if response.error.message :
184
+ raise Exception("{}\n For more informations check : https://cloud.google.com/apis/design/errors".format(response.error.message))
185
+
186
+ donnees = ocr_texts[0].split('\n')
187
+
188
+ if type_de_piece ==1:
189
+ return info_new_cni(donnees)
190
+ elif type_de_piece == 2:
191
+ return info_ancien_cni(donnees)
192
+ elif type_de_piece == 3:
193
+ return permis_de_conduite(donnees)
194
+ else :
195
+ return "Le traitement de ce type de document n'est pas encore pris en charge"
196
+
197
+ def load_checkpoint(path):
198
+ print('--> Loading checkpoint')
199
+ return torch.load(path,map_location=torch.device('cpu'))
200
+
201
+ def make_prediction(image_path):
202
+
203
+ # define the using of GPU or CPU et background training
204
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
205
+ ## load model
206
+ model = load_checkpoint("data/model.pth")
207
+ ## transformation
208
+ test_transforms = A.Compose([
209
+ A.Resize(height=224, width=224, always_apply=True),
210
+ A.Normalize(always_apply=True),
211
+ ToTensorV2(always_apply=True),])
212
+
213
+ ## read the image
214
+ image = np.array(Image.open(image_path).convert('RGB'))
215
+ transformed = test_transforms(image= image)
216
+ image_transformed = transformed["image"]
217
+ image_transformed = image_transformed.unsqueeze(0)
218
+ image_transformed = image_transformed.to(device)
219
+
220
+ model.eval()
221
+ with torch.set_grad_enabled(False):
222
+ output = model(image_transformed)
223
+
224
+ # Post-process predictions
225
+ probabilities = torch.nn.functional.softmax(output[0], dim=0)
226
+ predicted_class = torch.argmax(probabilities).item()
227
+ proba = float(max(probabilities))
228
+
229
+
230
+ return proba, predicted_class