Commit: init
- README.md +13 -13
- app.py +305 -0
- gptwm.py +114 -0
- requirements.txt +5 -0
- run_detect.py +58 -0
- run_generate.py +106 -0
README.md
CHANGED
@@ -1,13 +1,13 @@
 ---
-title: Unigram
-emoji:
+title: Unigram-Watermark
+emoji: π
 colorFrom: purple
-colorTo:
+colorTo: indigo
 sdk: gradio
 sdk_version: 4.7.1
 app_file: app.py
 pinned: false
 license: mit
 ---

 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py
ADDED
@@ -0,0 +1,305 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import argparse
from argparse import Namespace
from pprint import pprint
from functools import partial

import numpy  # for gradio hot reload
import gradio as gr
import pathlib
import torch

from transformers import (AutoTokenizer,
                          AutoModelForSeq2SeqLM,
                          AutoModelForCausalLM,
                          LogitsProcessorList,
                          LlamaTokenizer)

from gptwm import GPTWatermarkDetector, GPTWatermarkLogitsWarper


def str2bool(v):
    """Util function for user-friendly boolean flag args"""
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')


def parse_args():
    """Command line argument specification"""

    parser = argparse.ArgumentParser()

    parser.add_argument("--run_gradio", type=str2bool, default=True,
                        help="Whether to launch as a gradio demo. Set to False if gradio is not installed and you just want the stdout version.")
    parser.add_argument("--model_name", type=str, default="facebook/opt-125m")
    parser.add_argument("--fraction", type=float, default=0.5)
    parser.add_argument("--strength", type=float, default=2.0)
    parser.add_argument("--wm_key", type=int, default=0)
    parser.add_argument("--max_new_tokens", type=int, default=300)
    parser.add_argument("--beam_size", type=int, default=None)
    parser.add_argument("--top_k", type=int, default=None)
    parser.add_argument("--top_p", type=float, default=0.9)
    parser.add_argument("--test_min_tokens", type=int, default=200)
    parser.add_argument("--threshold", type=float, default=6.0)
    args = parser.parse_args()
    return args


def load_model(args):
    """Load and return the model and tokenizer"""
    hf_token = os.getenv('HF_TOKEN')
    if 'llama' in args.model_name:
        tokenizer = LlamaTokenizer.from_pretrained(args.model_name, use_auth_token=hf_token, torch_dtype=torch.float16)
    else:
        tokenizer = AutoTokenizer.from_pretrained(args.model_name, use_auth_token=hf_token, torch_dtype=torch.float16)
    model = AutoModelForCausalLM.from_pretrained(args.model_name, use_auth_token=hf_token, device_map='auto')
    model.eval()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    return model, tokenizer, device


def generate(prompt, args, model=None, device=None, tokenizer=None):
    print(f"Generating with {args}")

    watermark_processor = LogitsProcessorList([GPTWatermarkLogitsWarper(fraction=args.fraction,
                                                                        strength=args.strength,
                                                                        vocab_size=model.config.vocab_size,
                                                                        watermark_key=args.wm_key)])

    batch = tokenizer(prompt, truncation=True, return_tensors="pt").to(device)
    num_tokens = len(batch['input_ids'][0])
    with torch.inference_mode():
        generate_args = {
            **batch,
            'output_scores': True,
            'return_dict_in_generate': True,
            'max_new_tokens': args.max_new_tokens,
        }

        if args.beam_size is not None:
            generate_args['num_beams'] = args.beam_size
        else:
            generate_args['do_sample'] = True
            generate_args['top_k'] = args.top_k
            generate_args['top_p'] = args.top_p

        generate_without_watermark = partial(
            model.generate,
            **generate_args
        )
        output_without_watermark = generate_without_watermark()
        decoded_output_without_watermark = tokenizer.batch_decode(output_without_watermark['sequences'][:, num_tokens:], skip_special_tokens=True)[0]
        generate_with_watermark = partial(
            model.generate,
            logits_processor=watermark_processor,
            **generate_args
        )
        output_with_watermark = generate_with_watermark()
        decoded_gen_text_with_wm = tokenizer.batch_decode(output_with_watermark['sequences'][:, num_tokens:], skip_special_tokens=True)[0]

    return (prompt,
            decoded_output_without_watermark,
            decoded_gen_text_with_wm,
            args)


def detect_demo(input_text, args, device=None, tokenizer=None):

    vocab_size = 50272 if "opt" in args.model_name else tokenizer.vocab_size

    watermark_detector = GPTWatermarkDetector(fraction=args.fraction,
                                              strength=args.strength,
                                              vocab_size=vocab_size,
                                              watermark_key=args.wm_key)
    output = []
    html_output = ["Input text is too short to test."]
    tokens = tokenizer(input_text, add_special_tokens=False)
    gen_tokens = tokens["input_ids"]
    if len(gen_tokens) >= args.test_min_tokens:
        z_score, green_tokens_mask, green_tokens, total_tokens = watermark_detector.detect(gen_tokens)
        output.append(['z-score', f"{z_score:.3g}"])
        output.append(['green_tokens', f"{int(green_tokens):d}"])
        output.append(['total_tokens', f"{int(total_tokens):d}"])
        tokenarray = [tokens.token_to_chars(i) for i in range(0, len(gen_tokens))]
        tags = [(f'<span class="green">{input_text[word.start:word.end]}</span>' if b else f'<span class="red">{input_text[word.start:word.end]}</span>') for word, b in zip(tokenarray, green_tokens_mask)]
        html_output = f'<p>{" ".join(tags)}</p>'
    else:
        print("Input text is too short to test.")
    return output, html_output, args


def run_gradio(args, model=None, device=None, tokenizer=None):
    """Define and launch the gradio demo interface"""
    css = """
    .green {
        color: #008000 !important;
        border: none;
        font-weight: bold;
    }
    .red {
        color: #ffad99 !important;
        border: none;
        font-weight: bold;
    }
    """

    generate_partial = partial(generate, model=model, device=device, tokenizer=tokenizer)
    detect_partial = partial(detect_demo, device=device, tokenizer=tokenizer)

    with gr.Blocks(css=css) as demo:
        # Top section, greeting and instructions
        with gr.Row():
            with gr.Row():
                with gr.Column(scale=9):
                    gr.Markdown(
                        """
                        ## π Unigram-Watermark for AI-Generated Text

                        ## [Paper](https://arxiv.org/abs/2306.17439) [GitHub](https://github.com/XuandongZhao/Unigram-Watermark)
                        """
                    )

        with gr.Accordion("Abstract", open=True):
            gr.Markdown(
                """
                We instantiate our language model watermarking with the **Unigram-Watermark**, a variant of the K-gram watermark.

                We prove that our watermark method enjoys guaranteed generation quality, correctness in watermark detection, and is robust against text editing and paraphrasing.
                """
            )

        gr.Markdown(f"Language model: {args.model_name}")

        # Construct state for parameters, define updates and toggles
        default_prompt = args.__dict__.pop("default_prompt")
        session_args = gr.State(value=args)

        with gr.Tab("Method"):
            with gr.Accordion("Watermark process", open=True):
                gr.Markdown(
                    """
                    1. Randomly partition the vocabulary into two distinct sets: the green list with $\gamma N$ tokens and the red list with the remaining tokens.
                    2. In $\hat{M}$, the logits of the language model for the green list tokens are increased by $\delta$ while the logits for tokens in the red list remain unchanged.
                    """
                )
            with gr.Accordion("Detect process", open=True):
                gr.Markdown(
                    """
                    1. Count the number of green tokens in the suspect text.

                    2. Normalize the test-statistic $z_{y}=(|y|_G-\gamma n) / \sqrt{n \gamma(1-\gamma)}$.

                    3. Make a calibrated decision on whether we think the suspect text is generated from $\hat{M}$ or not.
                    """
                )

        with gr.Tab("Generate and Detect"):

            with gr.Row():
                prompt = gr.Textbox(label="Prompt", interactive=True, lines=10, max_lines=10, value=default_prompt)
            with gr.Row():
                generate_btn = gr.Button("Generate")
            with gr.Row():
                with gr.Column(scale=1):
                    with gr.Tab("Output Without Watermark"):
                        output_without_watermark = gr.Textbox(label="Text", interactive=False, lines=10, max_lines=10)
                    with gr.Tab("Visualization"):
                        html_without_watermark = gr.HTML(elem_id="html-without-watermark")
                with gr.Column(scale=1):
                    without_watermark_detection_result = gr.Dataframe(headers=["Metric", "Value"], interactive=False, row_count=7, col_count=2)
            with gr.Row():
                with gr.Column(scale=1):
                    with gr.Tab("Output With Watermark"):
                        output_with_watermark = gr.Textbox(label="Text", interactive=False, lines=10, max_lines=10)
                    with gr.Tab("Visualization"):
                        html_with_watermark = gr.HTML(elem_id="html-with-watermark")
                with gr.Column(scale=1):
                    with_watermark_detection_result = gr.Dataframe(headers=["Metric", "Value"], interactive=False, row_count=7, col_count=2)

        redecoded_input = gr.Textbox(visible=False)
        truncation_warning = gr.Number(visible=False)

        def truncate_prompt(redecoded_input, truncation_warning, orig_prompt, args):
            if truncation_warning:
                return redecoded_input + "\n\n[Prompt was truncated before generation due to length...]", args
            else:
                return orig_prompt, args

        with gr.Tab("Detector Only"):
            with gr.Row():
                with gr.Column(scale=2):
                    # detect inputbox
                    with gr.Tab("Text to Analyze"):
                        detection_input = gr.Textbox(label="Input", interactive=True, lines=14, max_lines=14)
                    with gr.Tab("Visualization"):
                        html_detection = gr.HTML(elem_id="html-detection")
                with gr.Column(scale=1):
                    detection_result = gr.Dataframe(headers=["Metric", "Value"], interactive=False, row_count=7, col_count=2)
            with gr.Row():
                # detect
                detect_btn = gr.Button("Detect")

        generate_btn.click(fn=generate_partial, inputs=[prompt, session_args], outputs=[redecoded_input, output_without_watermark, output_with_watermark, session_args])
        # Show truncated version of prompt if truncation occurred
        redecoded_input.change(fn=truncate_prompt, inputs=[redecoded_input, truncation_warning, prompt, session_args], outputs=[prompt, session_args])
        # Call detection when the outputs (of the generate function) are updated
        output_without_watermark.change(fn=detect_partial, inputs=[output_without_watermark, session_args], outputs=[without_watermark_detection_result, html_without_watermark, session_args])
        output_with_watermark.change(fn=detect_partial, inputs=[output_with_watermark, session_args], outputs=[with_watermark_detection_result, html_with_watermark, session_args])
        # Register main detection tab click
        detect_btn.click(fn=detect_partial, inputs=[detection_input, session_args], outputs=[detection_result, html_detection, session_args])

    demo.launch()


def main(args):
    """Run a command line version of the generation and detection operations
    and optionally launch and serve the gradio demo"""
    # Initial arg processing and log

    model, tokenizer, device = load_model(args)

    # Generate and detect, report to stdout
    input_text = (
        "One tank tumbled down an embankment into the Tenaru River, drowning its crew."
        " At 23:00 on 14 September, the remnants of the Kuma battalion conducted another attack on the same portion of the Marine lines, but were repulsed. "
        "A final \"weak\" attack by the Kuma unit on the evening of 15 September was also defeated. Oka's unit of about 650 men attacked the Marines at several locations on the west side of the Lunga perimeter."
        " At about 04:00 on 14 September, two Japanese companies attacked positions held by the 3rd Battalion, 5th Marine Regiment (3/5) near the coast and were thrown back with heavy losses."
        " Another Japanese company captured a small ridge somewhat inland but was then pinned down by Marine artillery fire throughout the day and took heavy losses before withdrawing on the evening of 14 September."
        " The rest of Oka's unit failed to find the Marine lines and did not participate in the attack. "
        "At 13:05 on 14 September, Kawaguchi led the survivors of his shattered brigade away from the ridge and deeper into the jungle, where they rested and tended to their wounded all the next day. "
        "Kawaguchi's units were then ordered to withdraw west to the Matanikau River valley to join with Oka's unit, a march over difficult terrain."
        " Kawaguchi's troops began the march on the morning of 16 September."
        " Almost every soldier able to walk had to help carry the wounded. "
        "As the march progressed, the exhausted and hungry soldiers, who had eaten their last rations on the morning before their withdrawal, began to discard their heavy equipment and then their rifles. "
        "By the time most of them reached Oka's positions at Kokumbona five days later, only half still carried their weapons."
        " The Kuma battalion's survivors, attempting to follow Kawaguchi's Center Body forces, became lost, wandered for three weeks in the jungle, and almost starved to death before finally reaching Kawaguchi's camp."
    )

    args.default_prompt = input_text

    # Launch the app to generate and detect interactively (implements the hf space demo)
    if args.run_gradio:
        run_gradio(args, model=model, tokenizer=tokenizer, device=device)

    return


if __name__ == "__main__":

    args = parse_args()
    print(args)

    main(args)
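The z-score described in the demo's "Detect process" tab can be checked by hand. A minimal sketch, with purely illustrative counts (a 200-token text containing 145 green-list tokens, green-list fraction 0.5; none of these numbers come from a real generation):

import math

# Illustrative numbers only.
gamma = 0.5        # green-list fraction
n = 200            # tokens in the suspect text
num_green = 145    # tokens that fall in the green list

z = (num_green - gamma * n) / math.sqrt(n * gamma * (1 - gamma))
print(f"z = {z:.2f}")  # ~6.36, above the app's default --threshold of 6.0

With these counts the text would be flagged as watermarked under the default threshold, while an unwatermarked text of the same length would be expected to land near z = 0.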
gptwm.py
ADDED
@@ -0,0 +1,114 @@
import hashlib
from typing import List
import numpy as np
from scipy.stats import norm
import torch
from transformers import LogitsWarper


class GPTWatermarkBase:
    """
    Base class for watermarking distributions with fixed-group green-listed tokens.

    Args:
        fraction: The fraction of the distribution to be green-listed.
        strength: The strength of the green-listing. Higher values result in higher logit scores for green-listed tokens.
        vocab_size: The size of the vocabulary.
        watermark_key: The random seed for the green-listing.
    """

    def __init__(self, fraction: float = 0.5, strength: float = 2.0, vocab_size: int = 50257, watermark_key: int = 0):
        rng = np.random.default_rng(self._hash_fn(watermark_key))
        mask = np.array([True] * int(fraction * vocab_size) + [False] * (vocab_size - int(fraction * vocab_size)))
        rng.shuffle(mask)
        self.green_list_mask = torch.tensor(mask, dtype=torch.float32)
        self.strength = strength
        self.fraction = fraction

    @staticmethod
    def _hash_fn(x: int) -> int:
        """solution from https://stackoverflow.com/questions/67219691/python-hash-function-that-returns-32-or-64-bits"""
        x = np.int64(x)
        return int.from_bytes(hashlib.sha256(x).digest()[:4], 'little')


class GPTWatermarkLogitsWarper(GPTWatermarkBase, LogitsWarper):
    """
    LogitsWarper for watermarking distributions with fixed-group green-listed tokens.

    Args:
        fraction: The fraction of the distribution to be green-listed.
        strength: The strength of the green-listing. Higher values result in higher logit scores for green-listed tokens.
        vocab_size: The size of the vocabulary.
        watermark_key: The random seed for the green-listing.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.FloatTensor:
        """Add the watermark to the logits and return new logits."""
        watermark = self.strength * self.green_list_mask
        new_logits = scores + watermark.to(scores.device)
        return new_logits


class GPTWatermarkDetector(GPTWatermarkBase):
    """
    Class for detecting watermarks in a sequence of tokens.

    Args:
        fraction: The fraction of the distribution to be green-listed.
        strength: The strength of the green-listing. Higher values result in higher logit scores for green-listed tokens.
        vocab_size: The size of the vocabulary.
        watermark_key: The random seed for the green-listing.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    @staticmethod
    def _z_score(num_green: int, total: int, fraction: float) -> float:
        """Calculate and return the z-score of the number of green tokens in a sequence."""
        return (num_green - fraction * total) / np.sqrt(fraction * (1 - fraction) * total)

    @staticmethod
    def _compute_tau(m: int, N: int, alpha: float) -> float:
        """
        Compute the threshold tau for the dynamic thresholding.

        Args:
            m: The number of unique tokens in the sequence.
            N: Vocabulary size.
            alpha: The false positive rate to control.
        Returns:
            The threshold tau.
        """
        factor = np.sqrt(1 - (m - 1) / (N - 1))
        tau = factor * norm.ppf(1 - alpha)
        return tau

    def detect(self, sequence: List[int]) -> float:
        """Detect the watermark in a sequence of tokens and return the z value."""
        green_tokens = int(sum(self.green_list_mask[i] for i in sequence))
        green_tokens_mask = []
        for i in sequence:
            if self.green_list_mask[i]:
                green_tokens_mask.append(True)
            else:
                green_tokens_mask.append(False)
        # self.green_tokens_mask = green_tokens_mask

        return self._z_score(green_tokens, len(sequence), self.fraction), green_tokens_mask, green_tokens, len(sequence)

    def unidetect(self, sequence: List[int]) -> float:
        """Detect the watermark in a sequence of tokens and return the z value. Just for unique tokens."""
        sequence = list(set(sequence))
        green_tokens = int(sum(self.green_list_mask[i] for i in sequence))
        return self._z_score(green_tokens, len(sequence), self.fraction)

    def dynamic_threshold(self, sequence: List[int], alpha: float, vocab_size: int) -> (bool, float):
        """Dynamic thresholding for watermark detection. True if the sequence is watermarked, False otherwise."""
        z_score = self.unidetect(sequence)
        tau = self._compute_tau(len(list(set(sequence))), vocab_size, alpha)
        return z_score > tau, z_score
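To exercise gptwm.py outside the Gradio app, here is a minimal sketch using a toy vocabulary and dummy logits instead of a real language model; all sizes and values are illustrative:

import torch
from gptwm import GPTWatermarkDetector, GPTWatermarkLogitsWarper

vocab_size = 100  # toy vocabulary, for illustration only
warper = GPTWatermarkLogitsWarper(fraction=0.5, strength=2.0, vocab_size=vocab_size, watermark_key=0)
detector = GPTWatermarkDetector(fraction=0.5, strength=2.0, vocab_size=vocab_size, watermark_key=0)

# Dummy logits for one decoding step: the warper adds `strength` to every green-list token.
scores = torch.zeros(1, vocab_size)
boosted = warper(input_ids=torch.tensor([[0]]), scores=scores)
assert torch.allclose(boosted - scores, 2.0 * warper.green_list_mask)

# Pretend the top boosted tokens were sampled, then score that sequence.
sequence = torch.topk(boosted[0], k=30).indices.tolist()
z_score, green_mask, num_green, total = detector.detect(sequence)
print(f"z-score: {z_score:.2f} ({num_green}/{total} green tokens)")

Because both classes derive the green list from the same watermark_key, the detector recovers exactly the partition the warper used; with a different key the z-score would hover around zero.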
requirements.txt
ADDED
@@ -0,0 +1,5 @@
torch
transformers
scipy
accelerate
pathlib
run_detect.py
ADDED
@@ -0,0 +1,58 @@
import argparse
import json
from tqdm import tqdm
import torch
from transformers import AutoTokenizer, LlamaTokenizer
from gptwm import GPTWatermarkDetector


def main(args):
    with open(args.input_file, 'r') as f:
        data = [json.loads(x) for x in f.read().strip().split("\n")]
    if 'llama' in args.model_name:
        tokenizer = LlamaTokenizer.from_pretrained(args.model_name, torch_dtype=torch.float16)
    else:
        tokenizer = AutoTokenizer.from_pretrained(args.model_name, torch_dtype=torch.float16)

    vocab_size = 50272 if "opt" in args.model_name else tokenizer.vocab_size

    detector = GPTWatermarkDetector(fraction=args.fraction,
                                    strength=args.strength,
                                    vocab_size=vocab_size,
                                    watermark_key=args.wm_key)

    z_score_list = []
    for idx, cur_data in tqdm(enumerate(data), total=len(data)):
        gen_tokens = tokenizer(cur_data['gen_completion'][0], add_special_tokens=False)["input_ids"]
        if len(gen_tokens) >= args.test_min_tokens:
            # detect() returns (z_score, green_mask, num_green, total); keep only the z-score,
            # cast to a plain float so the result stays JSON-serializable.
            z_score_list.append(float(detector.detect(gen_tokens)[0]))
        else:
            print(f"Warning: sequence {idx} is too short to test.")

    save_dict = {
        'z_score': z_score_list,
        'wm_pred': [1 if z > args.threshold else 0 for z in z_score_list]
    }

    print(save_dict)
    with open(args.input_file.replace('.jsonl', '_z.jsonl'), 'w') as f:
        json.dump(save_dict, f)

    print('Finished!')


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    # parser.add_argument("--model_name", type=str, default="facebook/opt-125m")
    parser.add_argument("--model_name", type=str, default="decapoda-research/llama-7b-hf")
    parser.add_argument("--fraction", type=float, default=0.5)
    parser.add_argument("--strength", type=float, default=2.0)
    parser.add_argument("--threshold", type=float, default=6.0)
    parser.add_argument("--wm_key", type=int, default=0)
    parser.add_argument("--input_file", type=str, default="./data/example_output.jsonl")
    parser.add_argument("--test_min_tokens", type=int, default=200)

    args = parser.parse_args()

    main(args)
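run_detect.py expects --input_file to be a JSONL file in which each record carries a 'gen_completion' field holding a list of generated strings (the format written by run_generate.py below); it writes the z-scores and thresholded predictions to a sibling *_z.jsonl file. A minimal sketch of building such a file by hand, with placeholder text (completions shorter than --test_min_tokens tokens are skipped):

import json

# Hypothetical records mirroring run_generate.py's output format; all text is a placeholder.
records = [
    {
        "prefix": "An example prompt...",
        "gold_completion": "A reference continuation...",
        "gen_completion": ["A model-generated continuation of at least --test_min_tokens tokens..."],
    },
]

with open("./data/example_output.jsonl", "w") as f:
    for record in records:
        f.write(json.dumps(record) + "\n")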
run_generate.py
ADDED
@@ -0,0 +1,106 @@
import argparse
from tqdm import tqdm
import json
import torch
import os
from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer, LogitsProcessorList
from gptwm import GPTWatermarkLogitsWarper


def read_file(filename):
    with open(filename, "r") as f:
        return [json.loads(line) for line in f.read().strip().split("\n")]


def write_file(filename, data):
    with open(filename, "a") as f:
        f.write("\n".join(data) + "\n")


def main(args):
    output_file = f"{args.output_dir}/{args.model_name.replace('/', '-')}_strength_{args.strength}_frac_{args.fraction}_len_{args.max_new_tokens}_num_{args.num_test}.jsonl"
    if 'llama' in args.model_name:
        tokenizer = LlamaTokenizer.from_pretrained(args.model_name, torch_dtype=torch.float16)
    else:
        tokenizer = AutoTokenizer.from_pretrained(args.model_name, torch_dtype=torch.float16)
    model = AutoModelForCausalLM.from_pretrained(args.model_name, device_map='auto')
    model.eval()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    watermark_processor = LogitsProcessorList([GPTWatermarkLogitsWarper(fraction=args.fraction,
                                                                        strength=args.strength,
                                                                        vocab_size=model.config.vocab_size,
                                                                        watermark_key=args.wm_key)])

    data = read_file(args.prompt_file)
    num_cur_outputs = len(read_file(output_file)) if os.path.exists(output_file) else 0

    outputs = []

    for idx, cur_data in tqdm(enumerate(data), total=min(len(data), args.num_test)):
        if idx < num_cur_outputs or len(outputs) >= args.num_test:
            continue

        if "gold_completion" not in cur_data and 'targets' not in cur_data:
            continue
        elif "gold_completion" in cur_data:
            prefix = cur_data['prefix']
            gold_completion = cur_data['gold_completion']
        else:
            prefix = cur_data['prefix']
            gold_completion = cur_data['targets'][0]

        batch = tokenizer(prefix, truncation=True, return_tensors="pt").to(device)
        num_tokens = len(batch['input_ids'][0])

        with torch.inference_mode():
            generate_args = {
                **batch,
                'logits_processor': watermark_processor,
                'output_scores': True,
                'return_dict_in_generate': True,
                'max_new_tokens': args.max_new_tokens,
            }

            if args.beam_size is not None:
                generate_args['num_beams'] = args.beam_size
            else:
                generate_args['do_sample'] = True
                generate_args['top_k'] = args.top_k
                generate_args['top_p'] = args.top_p

            generation = model.generate(**generate_args)
            gen_text = tokenizer.batch_decode(generation['sequences'][:, num_tokens:], skip_special_tokens=True)

        outputs.append(json.dumps({
            "prefix": prefix,
            "gold_completion": gold_completion,
            "gen_completion": gen_text
        }))

        if (idx + 1) % 10 == 0:
            write_file(output_file, outputs)
            outputs = []
            break

    write_file(output_file, outputs)
    print("Finished!")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument("--model_name", type=str, default="facebook/opt-125m")
    # parser.add_argument("--model_name", type=str, default="decapoda-research/llama-7b-hf")
    parser.add_argument("--fraction", type=float, default=0.5)
    parser.add_argument("--strength", type=float, default=2.0)
    parser.add_argument("--wm_key", type=int, default=0)
    parser.add_argument("--prompt_file", type=str, default="./data/LFQA/inputs.jsonl")
    parser.add_argument("--output_dir", type=str, default="./data/LFQA/")
    parser.add_argument("--max_new_tokens", type=int, default=300)
    parser.add_argument("--num_test", type=int, default=500)
    parser.add_argument("--beam_size", type=int, default=None)
    parser.add_argument("--top_k", type=int, default=None)
    parser.add_argument("--top_p", type=float, default=0.9)

    args = parser.parse_args()
    main(args)
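run_generate.py reads prompts from --prompt_file as JSONL; each record needs a 'prefix' plus either a 'gold_completion' string or a 'targets' list, and records with neither are skipped. A minimal sketch of a compatible prompt file, with placeholder values showing both accepted field layouts:

import json

# Hypothetical prompt records; the text fields are placeholders.
prompts = [
    {"prefix": "Explain why the sky is blue.", "gold_completion": "A reference answer..."},
    {"prefix": "What causes ocean tides?", "targets": ["A reference answer..."]},
]

with open("./data/LFQA/inputs.jsonl", "w") as f:
    for prompt in prompts:
        f.write(json.dumps(prompt) + "\n")

Watermarked completions are then appended to a JSONL file in --output_dir whose name encodes the model, strength, fraction, generation length, and number of test prompts.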