'''
Author: Qiguang Chen
LastEditors: Qiguang Chen
Date: 2023-01-23 17:26:47
LastEditTime: 2023-02-14 20:07:02
Description: Flask page that visualizes intent/slot prediction results read from
a JSON-lines output file, optionally shared publicly through a Gradio tunnel.
'''
import argparse
import json
import math
import os
import signal
import sys
import threading
import time

# Make the project root importable so `common.utils` resolves when run as a script.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from flask import Flask, request, render_template, render_template_string
from gradio import networking

from common.utils import load_yaml, str2bool


def _json_lines(file_obj):
    """Yield the decoded JSON object of every non-blank line in ``file_obj``.

    Blank lines are skipped so a stray empty line does not crash ``json.loads``.
    """
    for raw in file_obj:
        raw = raw.strip()
        if raw:
            yield json.loads(raw)


def get_example(start, end, predict_data_file_path):
    """Return display records for the samples whose index lies in [start, end].

    Both bounds are inclusive and indices are 0-based over non-blank lines.

    Args:
        start: first sample index to include.
        end: last sample index to include.
        predict_data_file_path: JSON-lines file where each record has the keys
            "text", "golden_intent", "pred_intent", "golden_slot" and "pred_slot".

    Returns:
        A list of dicts with the keys "text", "intent" (a one-element list of
        {"intent", "pred_intent"}) and "slot" (one {"text", "pred_slot", "slot"}
        dict per token). This structure is passed to the page template.
    """
    data_list = []
    with open(predict_data_file_path, "r", encoding="utf8") as f:
        for index, sample in enumerate(_json_lines(f)):
            if index < start:
                continue
            if index > end:
                break
            data_list.append({
                "text": sample["text"],
                "intent": [{"intent": sample["golden_intent"],
                            "pred_intent": sample["pred_intent"]}],
                # The zip order must match the unpacking order:
                # token, golden slot (s), predicted slot (ps).
                "slot": [{"text": t, "pred_slot": ps, "slot": s}
                         for t, s, ps in zip(sample["text"],
                                             sample["golden_slot"],
                                             sample["pred_slot"])],
            })
    return data_list


def _tally(stats, gold, pred):
    """Record one (gold, pred) pair in ``stats``, a confusion dict keyed by gold label.

    Each gold entry holds "_error_" (mismatch count), "_total_" (occurrence
    count) and one count per predicted label.
    """
    if gold not in stats:
        stats[gold] = {"_error_": 0, "_total_": 0}
    entry = stats[gold]
    if gold != pred:
        entry["_error_"] += 1
    entry[pred] = entry.get(pred, 0) + 1
    entry["_total_"] += 1


def _rank(stats):
    """Summarize error counts and sort each label's counts in place.

    Returns:
        A list of {"value": error_count, "name": label}.

    Side effect:
        Each ``stats[label]`` is replaced with a list of [key, count] pairs,
        sorted by count in descending order. The sort is stable, so ties keep
        their insertion order.
    """
    summary = [{"value": stats[name]["_error_"], "name": name} for name in stats]
    for item in summary:
        name = item["name"]
        ranked = sorted(stats[name].items(), key=lambda d: d[1], reverse=True)
        stats[name] = [[key, value] for key, value in ranked]
    return summary


def analysis(predict_data_file_path):
    """Compute per-label confusion statistics for intents and slots.

    Returns:
        (intent_dict_list, slot_dict_list, intent_dict, slot_dict, sample_num)

        - The ``*_list`` values are lists of {"value": errors, "name": label}.
        - The ``*_dict`` values map each golden label to count-sorted
          [key, count] pairs.
        - ``sample_num`` is the number of non-blank records read.
    """
    intent_dict = {}
    slot_dict = {}
    sample_num = 0
    with open(predict_data_file_path, "r", encoding="utf8") as f:
        for sample in _json_lines(f):
            sample_num += 1
            for s, ps in zip(sample["golden_slot"], sample["pred_slot"]):
                _tally(slot_dict, s, ps)
            _tally(intent_dict, sample["golden_intent"], sample["pred_intent"])
    intent_dict_list = _rank(intent_dict)
    slot_dict_list = _rank(slot_dict)
    return intent_dict_list, slot_dict_list, intent_dict, slot_dict, sample_num


parser = argparse.ArgumentParser()
parser.add_argument('--config_path', '-cp', type=str, default="config/visual.yaml")
parser.add_argument('--output_path', '-op', type=str, default=None)
parser.add_argument('--push_to_public', '-p', type=str2bool, nargs='?', const=True, default=None,
                    help="Push to public network.(Higher priority than config file)")
args = parser.parse_args()
button_html = ""
config = load_yaml(args.config_path)
# CLI flags override the values from the YAML config.
if args.output_path is not None:
    config["output_path"] = args.output_path
if args.push_to_public is not None:
    config["is_push_to_public"] = args.push_to_public
intent_dict_list, slot_dict_list, intent_dict, slot_dict, sample_num = analysis(config["output_path"])

PAGE_SIZE = config["page-size"]
# Ceiling division: exact multiples no longer get an extra empty page.
# Keep at least one page so an empty result file still renders.
PAGE_NUM = max(1, math.ceil(sample_num / PAGE_SIZE))
app = Flask(__name__, template_folder="static//template")


@app.route("/")
def hello():
    """Render one page of examples; the ``?page=N`` query selects the 0-based page."""
    # With type=int, a non-integer value falls back to the default (0) instead of raising.
    page = max(request.args.get('page', 0, type=int), 0)
    init_index = page * PAGE_SIZE
    examples = get_example(init_index, init_index + PAGE_SIZE - 1, config["output_path"])
    return render_template('visualization.html', examples=examples,
                           intent_dict_list=intent_dict_list, slot_dict_list=slot_dict_list,
                           intent_dict=intent_dict, slot_dict=slot_dict, page=page)


thread_lock_1 = False


class PushToPublicThread():
    """Daemon thread that exposes the local Flask server through a Gradio tunnel."""

    def __init__(self, config) -> None:
        self.thread = threading.Thread(target=self.push_to_public, args=(config,))
        # Stop flag, kept for backward compatibility. The event lets the worker sleep.
        self.thread_lock_2 = False
        self._stop_event = threading.Event()
        self.thread.daemon = True

    def start(self):
        self.thread.start()

    def push_to_public(self, config):
        """Open the tunnel, print the public URL, then block until exit() is called."""
        print("Push visualization results to public by Gradio....")
        print("Push to URL: ", networking.setup_tunnel(config["host"], str(config["port"])))
        print("This share link expires in 72 hours. And do not close this process for public sharing.")
        # Sleep on the event instead of spinning the CPU.
        # The timeout still honours thread_lock_2 if it is set directly.
        while not self.thread_lock_2:
            self._stop_event.wait(timeout=1)

    def exit(self, signum, frame):
        """SIGINT/SIGTERM handler: stop the worker and terminate the process."""
        self.thread_lock_2 = True
        self._stop_event.set()
        print("Exit..")
        # os._exit skips interpreter cleanup. This ensures the non-daemon Flask
        # thread cannot keep the process alive.
        os._exit(0)


if __name__ == '__main__':
    if config.get("is_push_to_public", False):
        thread_1 = threading.Thread(target=lambda: app.run(
            config["host"], config["port"]))
        thread_1.start()
        thread_2 = PushToPublicThread(config)
        signal.signal(signal.SIGINT, thread_2.exit)
        signal.signal(signal.SIGTERM, thread_2.exit)
        thread_2.start()
        # Keep the main thread alive; Python runs signal handlers in the main thread.
        while True:
            time.sleep(1)
    else:
        app.run(config["host"], config["port"])