from keras.models import load_model from PIL import Image, ImageOps import numpy as np import gradio as gr import pandas as pd import json import os import glob # === READ AND LOAD FILES === folder = '.' data = pd.read_csv(os.path.join(folder, 'species_info.csv')) with open(os.path.join(folder, 'translation.json'), 'r') as f: translation = json.load(f) # Load the model model = load_model(os.path.join(folder, 'keras_model.h5')) # Load label file with open(os.path.join(folder, 'labels.txt'),'r') as f: labels = f.readlines() # === GLOBAL VARIABLES === language = '' article = "" def format_label(label): """ From '0 rùa khác\n' to 'rùa khác' """ try: int(label.split(' ')[0]) return label[label.find(" ")+1:-1] except: return label[:-1] def get_name(scientific_name, lan): """ Return name in Vietnamese """ return data[data[f'scientific_name'] == scientific_name][f'name_{lan}'].to_list()[0] def get_fun_fact(scientific_name, lan): """ Return fun fact of the species """ return data[data[f'scientific_name'] == scientific_name][f'fun_fact_{lan}'].to_list()[0] def get_law(scientific_name): cites = data[data['scientific_name'] == scientific_name]['CITES'].to_list()[0] nd06 = data[data['scientific_name'] == scientific_name]['ND06'].to_list()[0] return cites, nd06 def get_habitat(scientific_name, lan): return data[data['scientific_name'] == scientific_name][f'habitat_{lan}'].to_list()[0] def get_conservation_status(scientific_name, lan): status_list = ['NE', 'DD', 'LC', 'NT', 'VU', 'EN', 'CR', 'EW', 'EX'] status = data[data['scientific_name'] == scientific_name]['IUCN'].to_list()[0] for s in status_list: if s in status: return translation['conservation_status'][s][lan] def get_language_code(lan): global language if lan == "Tiếng Việt": language = 'vi' if lan == "English": language = 'en' return language def get_species_list(): """ Example: ['Indotestudo elongata', 'Cuora galbinifrons', 'Cuora mouhotii', 'Cuora bourreti'] """ return [format_label(s) for s in labels] def get_species_abbreviation(scientific_name): return "".join([s[0] for s in scientific_name.split()]) def get_species_abbreviation_list(): """ Example: ['Ie', 'Cg', 'Cm', 'Cb'] """ return [get_species_abbreviation(s) for s in get_species_list()] def get_description(language): num_class = len(labels) num_native = 0 num_non_native = 0 native_list = '' non_native_list = '' for i in labels: label = format_label(i) if label in data[data.native == 'y'].scientific_name.values: num_native += 1 native_list += f"({num_native}) {get_name(label, language)}, " else: num_non_native += 1 non_native_list += f"({num_non_native}) {get_name(label, language)}, " if language=='vi': description=f""" VNTurtle nhận diện các loài rùa Việt Nam. Mô hình này có thể nhận diện **{num_class}** loại rùa thường xuất hiện ở VN gồm - **{num_native}** loài bản địa: {native_list} \n\n - **{num_non_native}** loài ngoại lai: {non_native_list} """ if language=='en': description=f""" VNTurtle can recognize turtle species in Vietnam. This model can identify {num_class} common turtles in Vietnam including **{num_native}** native species \n\n {native_list} \n\n and **{num_non_native}** non-native species \n\n {non_native_list} """ return description def update_language(language): language = get_language_code(language) return get_description(language), \ translation['label']['label_run_btn'][language], \ translation["accordion"]["fun_fact"][language], \ translation["accordion"]["status"][language], \ translation["accordion"]["law"][language], \ translation["accordion"]["info"][language] def predict(image): # Create the array of the right shape to feed into the keras model # The 'length' or number of images you can put into the array is # determined by the first position in the shape tuple, in this case 1. data = np.ndarray(shape=(1, 224, 224, 3), dtype=np.float32) #resize the image to a 224x224 with the same strategy as in TM2: #resizing the image to be at least 224x224 and then cropping from the center size = (224, 224) image = ImageOps.fit(image, size, Image.ANTIALIAS) #turn the image into a numpy array image_array = np.asarray(image) # Normalize the image normalized_image_array = (image_array.astype(np.float32) / 127.0) - 1 # Load the image into the array data[0] = normalized_image_array # run the inference pred = model.predict(data) pred = pred.tolist() return pred result = {} best_prediction = '' def interpret_prediction(prediction): global result sorted_index = np.argsort(prediction).tolist()[0] display_index = [] for i in sorted_index[::-1]: if prediction[0][i] > 0.01: display_index.append(i) # best_prediction = format_label(labels[sorted_index[-1]]).strip() result = {format_label(labels[i]): round(prediction[0][i],2) for i in display_index} # return best_prediction def run_btn_click(image): global best_prediction best_prediction = None global article article = translation["info"]["ATP_contact"][language] interpret_prediction(predict(image)) visible_result = [ False, False, False, False, False ] image_result = [ os.path.join(folder, 'examples', 'empty.JPG'), os.path.join(folder, 'examples', 'empty.JPG'), os.path.join(folder, 'examples', 'empty.JPG'), os.path.join(folder, 'examples', 'empty.JPG'), os.path.join(folder, 'examples', 'empty.JPG') ] percent_result = [ "", "", "", "", "" ] species_result = [ "", "", "", "", "" ] for i, (species, percent) in enumerate(result.items()): print(species, result) visible_result[i] = True image_result[i] = os.path.join(folder, 'examples', f'test_{get_species_abbreviation(species)}.JPG') percent_result[i] = f'{round(percent*100)}%' species_result[i] = species return gr.Accordion.update(open=True, visible=True), \ gr.Image.update(value=image_result[0], visible=visible_result[0]), \ gr.HighlightedText.update(value=[('', percent_result[0])], label=species_result[0], visible=visible_result[0]), \ gr.Button.update(visible=visible_result[0]), \ \ gr.Image.update(value=image_result[1], visible=visible_result[1]), \ gr.HighlightedText.update(value=[('', percent_result[1])], label=species_result[1], visible=visible_result[1]), \ gr.Button.update(visible=visible_result[1]), \ \ gr.Image.update(value=image_result[2], visible=visible_result[2]), \ gr.HighlightedText.update(value=[('', percent_result[2])], label=species_result[2], visible=visible_result[2]), \ gr.Button.update(visible=visible_result[2]), \ \ gr.Image.update(value=image_result[3], visible=visible_result[3]), \ gr.HighlightedText.update(value=[('', percent_result[3])], label=species_result[3], visible=visible_result[3]), \ gr.Button.update(visible=visible_result[3]), \ \ gr.Image.update(value=image_result[4], visible=visible_result[4]), \ gr.HighlightedText.update(value=[('', percent_result[4])], label=species_result[4], visible=visible_result[4]), \ gr.Button.update(visible=visible_result[4]), \ gr.Accordion.update(visible=False), \ [] # gr.Accordion.update(visible=False), \ # gr.Accordion.update(visible=False), \ # gr.Accordion.update(visible=False), \ # gr.Accordion.update(visible=False), \ # gr.Markdown.update(value=percent_result[4], visible=visible_result[4]), \ def get_image_gallery_species_1(): global best_prediction for i, name in enumerate(result): if i == 0: best_prediction = name return glob.glob(os.path.join(folder, 'gallery', name, '*')) def get_image_gallery_species_2(): global best_prediction for i, name in enumerate(result): if i == 1: best_prediction = name return glob.glob(os.path.join(folder, 'gallery', name, '*')) def get_image_gallery_species_3(): global best_prediction for i, name in enumerate(result): if i == 2: best_prediction = name return glob.glob(os.path.join(folder, 'gallery', name, '*')) def get_image_gallery_species_4(): global best_prediction for i, name in enumerate(result): if i == 3: best_prediction = name return glob.glob(os.path.join(folder, 'gallery', name, '*')) def get_image_gallery_species_5(): global best_prediction for i, name in enumerate(result): if i == 4: best_prediction = name return glob.glob(os.path.join(folder, 'gallery', name, '*')) def display_info(): cites, nd06 = get_law(best_prediction) fun_fact = f"{get_fun_fact(best_prediction, language)}." status = f"{get_conservation_status(best_prediction, language)}" law = f'CITES: {cites}, NĐ06: {nd06}' info = "" if str(nd06) != "": law_protection = translation["info"]["law_protection"][language] report = translation["info"]["report"][language] deliver = translation["info"]["deliver"][language] release = translation["info"]["release"][language] + f" **{get_habitat(best_prediction, language)}**" info = f"- {law_protection}\n\n- {report}\n\n- {deliver}\n\n- {release}" return gr.Accordion.update(visible=True), \ gr.Accordion.update(open=False), \ gr.Accordion.update(visible=True), \ gr.Accordion.update(visible=True), \ gr.Accordion.update(visible=True), \ gr.Accordion.update(visible=True), \ fun_fact, status, law, info default_lan = 'Tiếng Việt' with gr.Blocks() as demo: gr.Markdown("# VNTurtle") radio_lan = gr.Radio(choices=['Tiếng Việt', 'English'], value=default_lan, label='Ngôn ngữ/Language', show_label=True, interactive=True) md_des = gr.Markdown(get_description(get_language_code(default_lan))) with gr.Row(equal_height=True): inp = gr.Image(type="pil", show_label=True, label='Ảnh tải lên', interactive=True).style(height=250) gallery = gr.Gallery(show_label=True, label='Ảnh đối chiếu').style(grid=[4], height="auto") with gr.Row(): run_btn = gr.Button(translation['label']['label_run_btn'][get_language_code(default_lan)]) result_verify_btn = gr.Button(translation['label']['label_verify_btn'][get_language_code(default_lan)], visible=True) accordion_result_section = gr.Accordion(translation["accordion"]["result_section"][get_language_code(default_lan)], open=True, visible=False) with accordion_result_section: with gr.Row() as display_result: with gr.Column(scale=0.2, min_width=150) as result_1: result_percent_1 = gr.HighlightedText(show_label=True, visible=False).style(color_map={f'{i}%': 'green' for i in range(101)}) # result_percent_1 = gr.Markdown("", visible=False) result_img_1 = gr.Image(interactive=False, visible=False, show_label=False) result_view_btn_1 = gr.Button(translation['label']['label_check_btn'][get_language_code(default_lan)], visible=False) with gr.Column(scale=0.2, min_width=150) as result_2: result_percent_2 = gr.HighlightedText(show_label=True, visible=False).style(color_map={f'{i}%': 'yellow' for i in range(101)}) result_img_2 = gr.Image(interactive=False, visible=False, show_label=False) result_view_btn_2 = gr.Button(translation['label']['label_check_btn'][get_language_code(default_lan)], visible=False) with gr.Column(scale=0.2, min_width=150) as result_3: result_percent_3 = gr.HighlightedText(show_label=True, visible=False).style(color_map={f'{i}%': 'orange' for i in range(101)}) result_img_3 = gr.Image(interactive=False, visible=False, show_label=False) result_view_btn_3 = gr.Button(translation['label']['label_check_btn'][get_language_code(default_lan)], visible=False) with gr.Column(scale=0.2, min_width=150) as result_4: result_percent_4 = gr.HighlightedText(show_label=True, visible=False).style(color_map={f'{i}%': 'chocolate' for i in range(101)}) result_img_4 = gr.Image(interactive=False, visible=False, show_label=False) result_view_btn_4 = gr.Button(translation['label']['label_check_btn'][get_language_code(default_lan)], visible=False) with gr.Column(scale=0.2, min_width=150) as result_5: result_percent_5 = gr.HighlightedText(show_label=True, visible=False).style(color_map={f'{i}%': 'grey' for i in range(101)}) result_img_5 = gr.Image(interactive=False, visible=False, show_label=False) result_view_btn_5 = gr.Button(translation['label']['label_check_btn'][get_language_code(default_lan)], visible=False) accordion_info_section = gr.Accordion(translation['accordion']['info_section'][get_language_code(default_lan)], visible=False, open=True) with accordion_info_section: accordion_fun_fact = gr.Accordion(translation["accordion"]["fun_fact"][get_language_code(default_lan)], open=False, visible=False) accordion_status = gr.Accordion(translation["accordion"]["status"][get_language_code(default_lan)], open=False, visible=False) accordion_law = gr.Accordion(translation["accordion"]["law"][get_language_code(default_lan)], open=False, visible=False) accordion_info = gr.Accordion(translation["accordion"]["info"][get_language_code(default_lan)], open=False, visible=False) with accordion_fun_fact: md_fun_fact = gr.Markdown() with accordion_status: md_status = gr.Markdown() with accordion_law: md_law = gr.Markdown() with accordion_info: md_info = gr.Markdown() gr.Markdown("---") with gr.Accordion("🌅 Ảnh thử nghiệm", open=False): gr.Examples( examples=[[os.path.join(folder, 'examples', f'test_{get_species_abbreviation(s)}.JPG'), get_name(s, language)] for s in get_species_list()], inputs=[inp], label="" ) radio_lan.change(fn=update_language, inputs=[radio_lan], outputs=[ md_des, run_btn, accordion_fun_fact, accordion_status, accordion_law, accordion_info ]) run_btn.click(fn=run_btn_click, inputs=inp, outputs= [ accordion_result_section, # md_fun_fact, md_status, md_law, md_info, result_img_1, result_percent_1, result_view_btn_1, result_img_2, result_percent_2, result_view_btn_2, result_img_3, result_percent_3, result_view_btn_3, result_img_4, result_percent_4, result_view_btn_4, result_img_5, result_percent_5, result_view_btn_5, # accordion_fun_fact, accordion_status, accordion_law, accordion_info, accordion_info_section, gallery ], show_progress=True, scroll_to_output=True) result_view_btn_1.click(fn=get_image_gallery_species_1, outputs=gallery) result_view_btn_2.click(fn=get_image_gallery_species_2, outputs=gallery) result_view_btn_3.click(fn=get_image_gallery_species_3, outputs=gallery) result_view_btn_4.click(fn=get_image_gallery_species_4, outputs=gallery) result_view_btn_5.click(fn=get_image_gallery_species_5, outputs=gallery) result_verify_btn.click(fn=display_info, outputs=[ accordion_info_section, accordion_result_section, accordion_fun_fact, accordion_status, accordion_law, accordion_info, md_fun_fact, md_status, md_law, md_info, ], scroll_to_output=True) demo.launch(debug=False)