|
import torch |
|
import transformers |
|
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig |
|
import openai |
|
from openai import OpenAI |
|
|
|
def hide_encrypt(original_input, hide_model, tokenizer):
    """Produce the "hidden" form of a text by asking ``hide_model`` to paraphrase it.

    The text is substituted into a fixed paraphrase prompt, tokenized, and
    passed to ``hide_model.generate`` with sampling disabled and 3 beams.
    Only the tokens after the prompt are decoded and returned.

    Args:
        original_input: Text substituted into the paraphrase prompt.
        hide_model: Object exposing ``generate`` and ``device``
            (presumably a ``transformers`` causal LM -- confirm with caller).
        tokenizer: Tokenizer paired with ``hide_model``; it is called with
            ``return_tensors='pt'`` and later used to decode the output.

    Returns:
        The decoded continuation, with special tokens skipped.
    """
    hide_template = """<s>Paraphrase the text:%s\n\n"""
    prompt = hide_template % original_input

    encoded = tokenizer(prompt, return_tensors='pt').to(hide_model.device)
    prompt_len = len(encoded['input_ids'][0])

    # The generation budget scales with the prompt: 1.3x its token count.
    gen_config = GenerationConfig(
        max_new_tokens=int(prompt_len * 1.3),
        do_sample=False,
        num_beams=3,
        repetition_penalty=5.0,
    )
    generated = hide_model.generate(**encoded, generation_config=gen_config)

    # Drop the first ``prompt_len`` positions of the first returned sequence;
    # they are assumed to echo the prompt (decoder-only behaviour -- confirm).
    first_sequence = generated.cpu()[0]
    new_tokens = first_sequence[prompt_len:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)
|
|
|
def seek_decrypt(hide_input, hide_output, original_input, seek_model, tokenizer):
    """Ask ``seek_model`` to convert ``original_input`` using one worked example.

    The prompt shows ``hide_input`` followed by ``hide_output`` under a
    "Convert the text:" heading, then repeats the heading with
    ``original_input`` and lets the model continue. Generation runs with
    sampling disabled and 3 beams; only the tokens after the prompt are
    decoded and returned.

    Args:
        hide_input: First text of the example pair in the prompt.
        hide_output: Second text of the example pair in the prompt.
        original_input: Text placed under the second "Convert the text:"
            heading, for which the model produces the continuation.
        seek_model: Object exposing ``generate`` and ``device``
            (presumably a ``transformers`` causal LM -- confirm with caller).
        tokenizer: Tokenizer paired with ``seek_model``; it is called with
            ``return_tensors='pt'`` and later used to decode the output.

    Returns:
        The decoded continuation, with special tokens skipped.
    """
    seek_template = """Convert the text:\n%s\n\n%s\n\nConvert the text:\n%s\n\n"""
    prompt = seek_template % (hide_input, hide_output, original_input)

    encoded = tokenizer(prompt, return_tensors='pt').to(seek_model.device)
    prompt_len = len(encoded['input_ids'][0])

    # The generation budget scales with the prompt: 1.3x its token count.
    gen_config = GenerationConfig(
        max_new_tokens=int(prompt_len * 1.3),
        do_sample=False,
        num_beams=3,
    )
    generated = seek_model.generate(**encoded, generation_config=gen_config)

    # Drop the first ``prompt_len`` positions of the first returned sequence;
    # they are assumed to echo the prompt (decoder-only behaviour -- confirm).
    first_sequence = generated.cpu()[0]
    new_tokens = first_sequence[prompt_len:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)
|
|
|
def get_gpt_output(prompt, api_key=None, *, model="gpt-3.5-turbo"):
    """Send ``prompt`` as a single user message to an OpenAI chat model.

    Args:
        prompt: Text used as the content of the one ``user`` message.
        api_key: OpenAI API key. Must be a non-empty value; it is passed
            directly to the ``OpenAI`` client constructor.
        model: Name of the chat model to query. Defaults to
            ``"gpt-3.5-turbo"``, the value that was previously hard-coded.

    Returns:
        The ``message.content`` of the first choice in the completion.

    Raises:
        ValueError: If ``api_key`` is missing or empty. This is checked
            before any client is created or request is made.
    """
    if not api_key:
        raise ValueError('an open api key is needed for this function')

    client = OpenAI(api_key=api_key)
    completion = client.chat.completions.create(
        model=model,
        messages=[
            {"role": "user", "content": prompt},
        ],
    )
    # Only the first returned choice is used.
    return completion.choices[0].message.content