"""script to annotate the the datasets with using trained attribute prediciton model. |
|
First, we need to launch the NeMo Megatron inference server |
|
Example: |
|
```bash |
|
python examples/nlp/language_modeling/megatron_gpt_eval.py \ |
|
gpt_model_file=/models/TRAINED_ATTR_PREDICTION_MODEL.nemo \ |
|
pipeline_model_parallel_split_rank=0 \ |
|
server=True \ |
|
tensor_model_parallel_size=TP_SIZE \ |
|
pipeline_model_parallel_size=PP_SIZE \ |
|
trainer.precision=bf16 \ |
|
trainer.devices=TP_SIZE*PP_SIZE \ |
|
trainer.num_nodes=1 \ |
|
web_server=False \ |
|
port=1424 |
|
``` |
|
|
|
Then, we can run this script to annotate the dataset. |
|
Example usage: |
|
|
|
python scripts/nlp_language_modeling/sft/attribute_annotate.py --batch_size=1 --host=localhost --input_file_name=input.jsonl --output_file_name=output.jsonl --port_num=1424 |
|
""" |

import json
import os

import fire
import tqdm
from langchain.prompts import PromptTemplate

from nemo.collections.nlp.modules.common.megatron.retrieval_services.util import text_generation

# ISO 639-1 codes the model may emit for the 'lang' attribute.
langs = [
    'ar', 'bg', 'bn', 'ca', 'cs', 'da', 'de', 'el', 'en', 'eo', 'es', 'eu',
    'fa', 'fi', 'fr', 'gl', 'he', 'hu', 'id', 'it', 'ja', 'ko', 'nb', 'nl',
    'pl', 'pt', 'ro', 'ru', 'sk', 'sv', 'th', 'tr', 'uk', 'vi', 'zh',
]


SFT_PREFIX = """<extra_id_0>System
{system_message}"""

ONE_TURN_WITH_VAL = """<extra_id_1>{user_name}
{user_message}
<extra_id_2>{label}
"""

ONE_TURN_WITHOUT_VAL = """<extra_id_1>{user_name}
{user_message}
"""

SYSTEM = PromptTemplate(input_variables=["system_message"], template=SFT_PREFIX)
EXAMPLE_PROMPT_WITH_VAL = PromptTemplate(
    input_variables=["user_name", "user_message", "label"], template=ONE_TURN_WITH_VAL
)
EXAMPLE_PROMPT_WITHOUT_VAL = PromptTemplate(
    input_variables=["user_name", "user_message"], template=ONE_TURN_WITHOUT_VAL
)
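
# For reference, a prompt assembled from these templates for one prior user
# turn plus the assistant turn being annotated looks roughly like this (the
# exact whitespace follows the templates above; contents are illustrative):
#
#   <extra_id_0>System
#   {system message}
#   <extra_id_1>User
#   {user message}
#   <extra_id_1>Assistant
#   {assistant message}
#   <extra_id_2>quality:
#
# The server then greedily completes the last line with the attribute value.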

# Attributes predicted for each annotated turn, in generation order; 'lang' is
# handled separately at the end when process_lang is enabled.
selected_keys = [
    'quality', 'toxicity', 'humor', 'creativity', 'violence', 'helpfulness',
    'not_appropriate', 'hate_speech', 'sexual_content', 'fails_task',
    'political_content', 'moral_judgement', 'lang',
]
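
# A finished annotation is stored on each turn as a comma-separated string of
# 'key:value' pairs in the order above, e.g. (values illustrative):
#
#   quality:4,toxicity:0,humor:0,creativity:0,violence:0,helpfulness:4,
#   not_appropriate:0,hate_speech:0,sexual_content:0,fails_task:0,
#   political_content:0,moral_judgement:0,lang:en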


def calculate_key(obj):
    """Build a deduplication key by concatenating the text of every turn."""
    return ":".join([item['value'] for item in obj['conversations']])


def load_data(path):
    """Yield one JSON object per line of a JSONL file."""
    with open(path, 'r', encoding='utf-8') as fin:
        for line in fin:
            yield json.loads(line)


def get_prompt(data_obj, turn, current_label="", label_id=0):
    """Build the prompt that asks the model for attribute `label_id` of turn `turn`.

    Returns None when the conversation has no turn at index `turn`.
    """
    if len(data_obj['conversations']) < turn + 1:
        return None

    # Render every turn before the one being annotated, keeping any label a
    # turn already carries as in-context evidence.
    examples = []
    for i in range(0, turn):
        d = data_obj['conversations'][i]
        if 'label' in d:
            examples.append(
                EXAMPLE_PROMPT_WITH_VAL.format(
                    **{'user_name': d['from'], 'user_message': d['value'], 'label': d['label']}
                )
            )
        else:
            examples.append(EXAMPLE_PROMPT_WITHOUT_VAL.format(**{'user_name': d['from'], 'user_message': d['value']}))

    example_text = "".join(examples)
    d = data_obj['conversations'][turn]
    predict_message = EXAMPLE_PROMPT_WITHOUT_VAL.format(**{'user_name': d['from'], 'user_message': d['value']})

    # Either start a fresh label line or extend the partial one with the name
    # of the next attribute, leaving its value for the model to fill in.
    if label_id != 0:
        current_label = current_label + ',' + selected_keys[label_id] + ':'
    else:
        current_label = '<extra_id_2>' + selected_keys[label_id] + ':'
    return SYSTEM.format(**{'system_message': data_obj['system']}) + example_text + predict_message + current_label
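
# Usage sketch with a hypothetical sample: the first call for a turn is
# get_prompt(sample, turn, "", 0), which ends the prompt with
# '<extra_id_2>quality:'; subsequent calls feed the partial label back in,
# e.g. get_prompt(sample, turn, '<extra_id_2>quality:4', 1), which appends
# ',toxicity:' for the model to complete next.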


def create_gen_function(host='localhost', port=5555):
    """Return a closure that sends generation requests to the Megatron server."""

    def request(prompts, greedy, add_BOS, token_to_gen, min_tokens, temp, top_p, top_k, repetition, end_strings):
        data = {
            "sentences": prompts,
            "tokens_to_generate": int(token_to_gen),
            "temperature": temp,
            "add_BOS": add_BOS,
            "top_k": top_k,
            "top_p": top_p,
            "greedy": greedy,
            "all_probs": False,
            "repetition_penalty": repetition,
            "min_tokens_to_generate": int(min_tokens),
            "end_strings": end_strings,
        }
        response = text_generation(data, ip=host, port=port)
        sentences = response['sentences']
        return sentences

    return request
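
# Minimal usage sketch (the port mirrors the docstring example; the keyword
# values match the calls made in Worker below):
#
#   req = create_gen_function(host='localhost', port=1424)
#   completions = req(
#       ['<extra_id_0>System\n...'], greedy=True, add_BOS=False, token_to_gen=1,
#       min_tokens=1, temp=0.1, top_p=1.0, top_k=1, repetition=1.0,
#       end_strings=["<extra_id_1>", "<|endoftext|>"],
#   )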


class Worker(object):
    def __init__(self, host='localhost', port=5555, progress_bar=None, output_file=None, process_lang=False):
        self.req = create_gen_function(host=host, port=port)
        # Open in append mode so an interrupted run can resume without losing output.
        self.fout = open(output_file, "a", encoding='utf-8')
        self.progress_bar = progress_bar
        self.process_lang = process_lang

    def process_result(self, batch):
        """Annotate every turn of every sample in `batch`, writing finished samples out.

        Attributes are predicted one token at a time: each request greedily
        generates the value of the current attribute, the value is parsed and
        validated, and the extended label is folded back into the prompt for
        the next attribute. Samples whose predictions fail validation are dropped.
        """
        while True:
            try:
                items = [i['item'] for i in batch]
                turns = [i['turn'] for i in batch]
                prompts = [i['prompt'] for i in batch]

                for label_id in range(1, len(selected_keys)):
                    results = self.req(
                        prompts,
                        greedy=True,
                        add_BOS=False,
                        token_to_gen=1,
                        min_tokens=1,
                        temp=0.1,
                        top_p=1.0,
                        top_k=1,
                        repetition=1.0,
                        end_strings=["<extra_id_1>", "<|endoftext|>"],
                    )

                    # Parse the label line generated so far and pull out the value of
                    # the attribute just predicted (selected_keys[label_id - 1]).
                    current_values = []
                    nums = []
                    for result in results:
                        current_val = result.split('quality')[-1]
                        current_val = 'quality' + current_val
                        current_val = current_val.split('\n')[0].strip()
                        filtered = []
                        for part in current_val.split(','):
                            filtered.append(part)
                            if part.split(':')[0] == selected_keys[label_id - 1]:
                                nums.append(part.split(':')[1])
                                break
                        else:
                            # Keep nums aligned with results when the expected key is
                            # missing; the sentinel fails the int() check below.
                            nums.append('n/a')
                        current_values.append('<extra_id_2>' + ','.join(filtered))

                    # Keep only samples whose predicted value is an integer, and
                    # extend their prompts with the next attribute name.
                    filtered_items = []
                    filtered_turns = []
                    filtered_prompts = []
                    filtered_current_values = []
                    for result, item, turn, num, current_value in zip(results, items, turns, nums, current_values):
                        try:
                            int(num)
                        except Exception as e:
                            print(f'error {e} when converting {num} to int')
                            continue
                        filtered_current_values.append(current_value)
                        filtered_items.append(item)
                        filtered_turns.append(turn)
                        filtered_prompts.append(get_prompt(item, turn, current_label=current_value, label_id=label_id))
                    items = filtered_items
                    turns = filtered_turns
                    prompts = filtered_prompts
                    current_values = filtered_current_values

                if self.process_lang:
                    # One more request to fill in the value of the final 'lang' attribute.
                    results = self.req(
                        prompts,
                        greedy=True,
                        add_BOS=False,
                        token_to_gen=1,
                        min_tokens=1,
                        temp=0.1,
                        top_p=1.0,
                        top_k=1,
                        repetition=1.0,
                        end_strings=["<extra_id_1>", "<|endoftext|>"],
                    )

                    # The completed label line is the last line of the returned text; a
                    # trailing newline is swapped for '@' so the split keeps the line.
                    current_values = []
                    for result in results:
                        if result.endswith('\n'):
                            result = result[:-1] + '@'
                        current_values.append(result.split('\n')[-1])

                    # Extract the predicted language code ('lang' is the last key).
                    nums = []
                    for result in results:
                        current_val = result.split('quality')[-1]
                        current_val = 'quality' + current_val
                        current_val = current_val.split('\n')[0].strip()
                        for part in current_val.split(','):
                            if part.split(':')[0] == selected_keys[-1]:
                                nums.append(part.split(':')[1])
                                break
                        else:
                            # Sentinel keeps nums aligned; it is never a valid language.
                            nums.append('n/a')

                    # Keep only samples whose predicted language code is known.
                    filtered_items = []
                    filtered_turns = []
                    filtered_current_values = []
                    for result, item, turn, num, current_value in zip(results, items, turns, nums, current_values):
                        if num not in langs:
                            print(f'error {num} not in langs')
                            continue
                        filtered_current_values.append(current_value)
                        filtered_items.append(item)
                        filtered_turns.append(turn)
                    items = filtered_items
                    turns = filtered_turns
                    current_values = filtered_current_values

                # Attach the finished label to the current turn. Samples with turns
                # left re-enter the batch; finished samples are written out.
                batch = []
                for item, turn, current_value in zip(items, turns, current_values):
                    response_text = current_value[len('<extra_id_2>') :]
                    if 'label' in item['conversations'][turn]:
                        # Preserve a pre-existing human label before overwriting it.
                        item['conversations'][turn]['gt_label'] = item['conversations'][turn]['label']
                    item['conversations'][turn]['label'] = response_text
                    prompt = get_prompt(item, turn + 1, current_label='', label_id=0)
                    if prompt is not None:
                        batch.append({'prompt': prompt, 'item': item, 'turn': turn + 1})
                    else:
                        self.progress_bar.update(1)
                        self.fout.write(json.dumps(item, ensure_ascii=False) + "\n")
                        self.fout.flush()
                        if self.progress_bar.n >= self.progress_bar.total:
                            break
                if len(batch) == 0:
                    break
            except Exception as e:
                # Log the failure, bump the progress bar, and retry the batch
                # unless the bar is already full.
                print(f'error {e} when processing {batch}')
                self.progress_bar.update(1)
                if self.progress_bar.n >= self.progress_bar.total:
                    break


def main(
    batch_size=1,
    host='localhost',
    input_file_name='input.jsonl',
    output_file_name='output.jsonl',
    port_num=1424,
    process_lang=True,
):
    input_data = load_data(input_file_name)
    output_path = output_file_name

    # Resume support: skip samples whose key already appears in the output file.
    existing_requests = set()
    if os.path.exists(output_path):
        with open(output_path, 'r', encoding='utf-8') as fin:
            for line in fin:
                line = json.loads(line)
                existing_requests.add(calculate_key(line))
        print(f"Loaded {len(existing_requests)} existing requests")

    filter_data = [d for d in input_data if calculate_key(d) not in existing_requests]

    progress_bar = tqdm.tqdm(total=len(filter_data))

    worker = Worker(
        host=host, port=port_num, progress_bar=progress_bar, output_file=output_path, process_lang=process_lang
    )
    for batch_idx in range(0, len(filter_data), batch_size):
        batch = filter_data[batch_idx : batch_idx + batch_size]
        # Start at turn 0 when the sample carries no 'mask' role or the first
        # turn is spoken by the masked role; otherwise skip to turn 1.
        turns = [
            0 if 'mask' not in d or d['conversations'][0]['from'] == d['mask'] else 1
            for d in batch
        ]
        task = [{'prompt': get_prompt(d, turn, "", 0), 'item': d, 'turn': turn} for d, turn in zip(batch, turns)]
        worker.process_result(task)
    worker.fout.close()


if __name__ == '__main__':
    fire.Fire(main)