from argparse import ArgumentParser
from typing import *
from flask import Flask
from flask import request
from sftp import SpanPredictor, Span
# Command-line interface: the model path is required; -p and -d are optional.
parser = ArgumentParser()
parser.add_argument('model', metavar='MODEL_PATH', type=str)
# Presumably the port the Flask app listens on (it is consumed outside this view).
parser.add_argument('-p', metavar='PORT', type=int, default=7749)
# Forwarded to SpanPredictor.from_path as cuda_device; -1 presumably selects CPU — confirm.
parser.add_argument('-d', metavar='DEVICE', type=int, default=-1)
args = parser.parse_args()

# HTML page template for the demo. The path is relative, so it resolves against the
# current working directory (presumably the repository root). Read it through a
# context manager so the file handle is closed, and decode explicitly as UTF-8
# instead of relying on the platform's locale-dependent default encoding.
with open('tools/demo/flask_template.html', encoding='utf-8') as template_file:
    template = template_file.read()

# The model is loaded once, at module import time, on the device chosen by -d.
predictor = SpanPredictor.from_path(args.model, cuda_device=args.d)
app = Flask(__name__)

# Multilingual sample input made of space-separated tokens; presumably the sentence
# shown by default in the demo form — confirm against the request handler.
default_sentence = '因为 आरजू です vegan , هي купил soja .'
def visualized_prediction(inputs: List[str], prediction: Span, prefix=''):
spans = list()
span2event = [[] for _ in inputs]
for event_idx, event in enumerate(prediction):
for arg_idx, arg in enumerate(event):
for token_idx in range(arg.start_idx, arg.end_idx+1):
span2event[token_idx].append((event_idx, arg_idx))
for token_idx, token in enumerate(inputs):
class_labels = ' '.join(
['token'] + [f'{prefix}-arg-{event_idx}-{arg_idx}' for event_idx, arg_idx in span2event[token_idx]]
)
spans.append(f'{token} \n')
for event_idx, event in enumerate(prediction):
spans[event.start_idx] = (
f''
''
f''
+ spans[event.start_idx]
)
spans[event.end_idx] += f'
'.join(arg_tips)
spans[event.end_idx] += f'{arg_tips}\n'
spans[event.end_idx] += '\n'
return(
'