"""Flask demo server that renders span predictions for submitted sentences.

Usage: ``python <this_script> MODEL_PATH [-p PORT] [-d DEVICE]``

The page template is read from ``tools/demo/flask_template.html`` and the three
placeholders ``DEFAULT_SENTENCE``, ``VISUALIZED_PREDICTION`` and
``STRUCTURED_PREDICTION`` in it are substituted on every request.

NOTE(review): the HTML tags inside this file's string literals were missing
from the copy this revision was prepared from (only the text between the tags
survived).  The tags emitted below are a minimal, well-formed stand-in; the
exact tag names, ids and CSS classes must be reconciled with what
``flask_template.html`` and its stylesheet/scripts expect.
"""
import re
from argparse import ArgumentParser
from html import escape
from typing import List

from flask import Flask
from flask import request

from sftp import SpanPredictor, Span

parser = ArgumentParser()
parser.add_argument('model', metavar='MODEL_PATH', type=str)
parser.add_argument('-p', metavar='PORT', type=int, default=7749)
parser.add_argument('-d', metavar='DEVICE', type=int, default=-1)
args = parser.parse_args()

# Read the template once at start-up; ``with`` closes the handle, and the
# explicit encoding avoids depending on the platform's locale default.
with open('tools/demo/flask_template.html', encoding='utf-8') as template_file:
    template = template_file.read()

predictor = SpanPredictor.from_path(args.model, cuda_device=args.d)
app = Flask(__name__)
default_sentence = '因为 आरजू です vegan , هي купил soja .'

# All placeholders are substituted in ONE pass so that text inserted for one
# placeholder (e.g. a user sentence that happens to contain the word
# "VISUALIZED_PREDICTION") is never re-scanned and replaced by a later step.
_PLACEHOLDER_RE = re.compile(
    'DEFAULT_SENTENCE|VISUALIZED_PREDICTION|STRUCTURED_PREDICTION'
)


def _fill_template(sentence_html: str, visualized_html: str, structured_html: str) -> str:
    """Return the page template with its three placeholders substituted.

    All arguments are inserted verbatim, so they must already be HTML-safe.
    """
    values = {
        'DEFAULT_SENTENCE': sentence_html,
        'VISUALIZED_PREDICTION': visualized_html,
        'STRUCTURED_PREDICTION': structured_html,
    }
    # A callable replacement keeps backslashes in the values literal.
    return _PLACEHOLDER_RE.sub(lambda match: values[match.group(0)], template)


def visualized_prediction(inputs: List[str], prediction: Span, prefix=''):
    """Render one sentence as HTML with its events and arguments marked up.

    Args:
        inputs: Tokens of the sentence.
        prediction: Iterable of events; each event exposes ``start_idx``,
            ``end_idx`` (inclusive) and ``label`` and is itself iterable over
            its arguments, which expose the same three attributes.
        prefix: String prepended to generated class/id names so that several
            sentences can share one page (the caller passes ``sent<N>``).

    Returns:
        An HTML fragment containing one element per token, wrapped per event.
    """
    # For every token, the (event index, argument index) pairs covering it.
    token_args = [[] for _ in inputs]
    for event_idx, event in enumerate(prediction):
        for arg_idx, arg in enumerate(event):
            for token_idx in range(arg.start_idx, arg.end_idx + 1):
                token_args[token_idx].append((event_idx, arg_idx))

    spans = []
    for token, memberships in zip(inputs, token_args):
        class_labels = ' '.join(
            ['token']
            + [f'{prefix}-arg-{event_idx}-{arg_idx}' for event_idx, arg_idx in memberships]
        )
        # Tokens originate from the request, so they are escaped before being
        # embedded in markup.
        # NOTE(review): stand-in tag; original attributes were lost.
        spans.append(f'<span class="{escape(class_labels)}">{escape(token)} </span>\n')

    for event_idx, event in enumerate(prediction):
        # Open the event wrapper before its first token ...
        # NOTE(review): the original emitted three opening tags here whose
        # content was lost; a single wrapper is used as a stand-in.
        spans[event.start_idx] = (
            f'<span class="event" id="{escape(prefix)}-event-{event_idx}">'
            + spans[event.start_idx]
        )
        # ... and append the label, argument tips and closing tag after its last.
        spans[event.end_idx] += (
            f'<span class="event-label">{escape(str(event.label))}</span>'
        )
        arg_tips = [
            f'<span class="{escape(prefix)}-arg-{event_idx}-{arg_idx}">'
            f'{escape(str(arg.label))}</span>'
            for arg_idx, arg in enumerate(event)
        ]
        if arg_tips:
            tips_html = '<br>'.join(arg_tips)
            spans[event.end_idx] += f'<span class="event-tooltip">{tips_html}</span>\n'
        spans[event.end_idx] += '\n</span>'

    return '<div>\n' + '\n'.join(spans) + '\n</div>'


def structured_prediction(inputs, prediction):
    """Render one sentence's events and their arguments as an HTML list.

    Each event contributes a ``label: text`` item followed by one indented
    ``label: text`` item per argument, where ``text`` is the space-joined
    tokens from ``start_idx`` to ``end_idx`` inclusive.

    Args:
        inputs: Tokens of the sentence.
        prediction: Iterable of events (see ``visualized_prediction``).

    Returns:
        An HTML ``<ul>`` fragment.
    """
    items = []
    for event in prediction:
        event_text = ' '.join(inputs[event.start_idx:event.end_idx + 1])
        # NOTE(review): stand-in tags; original attributes were lost.
        items.append(
            f'<li><strong>{escape(str(event.label))}</strong>: {escape(event_text)}</li>'
        )
        for arg in event:
            arg_text = ' '.join(inputs[arg.start_idx:arg.end_idx + 1])
            items.append(
                '<li>&nbsp;&nbsp;&nbsp;&nbsp;'
                f'<strong>{escape(str(arg.label))}</strong>: {escape(arg_text)}</li>'
            )
    content = '\n'.join(items)
    # Previously ``content`` was built but never returned.
    return '\n<ul>\n' + content + '\n</ul>'


@app.route('/')
def sftp():
    """Serve the demo page.

    Without a ``sentence`` query parameter the page shows the default sentence
    and empty prediction areas.  Otherwise every non-blank line of the
    parameter is treated as one sentence, passed to the predictor, and both
    renderings of each result are inserted into the page.
    """
    tokens = request.args.get('sentence')
    if tokens is None:
        return _fill_template(escape(default_sentence), '', '')

    # ``splitlines`` also handles the ``\r\n`` line breaks browsers send for
    # multi-line form fields; blank lines are dropped rather than sent to the
    # model.
    sentences = [line for line in tokens.splitlines() if line.strip()]

    vis_pred, str_pred = [], []
    if sentences:
        model_outputs = predictor.predict_batch_sentences(sentences, max_tokens=512)
        for sent_idx, output in enumerate(model_outputs):
            vis_pred.append(
                visualized_prediction(output.sentence, output.span, f'sent{sent_idx}')
            )
            str_pred.append(structured_prediction(output.sentence, output.span))

    # The raw query string is user-controlled: escape it before echoing it
    # back into the page.
    return _fill_template(escape(tokens), '\n'.join(vis_pred), '\n'.join(str_pred))


app.run(port=args.p)