marksverdhei committed on
Commit
614d543
β€’
1 Parent(s): 2f35e98

WIP: Add attempt count

Browse files
Files changed (6) hide show
  1. app.py +4 -68
  2. src/handler.py +77 -0
  3. src/interface.py +68 -0
  4. src/state.py +27 -0
  5. text.py β†’ src/text.py +0 -0
  6. state.py +0 -10
app.py CHANGED
@@ -1,75 +1,11 @@
1
- import gradio as gr
2
- import torch
3
- from transformers import AutoModelForCausalLM
4
- from transformers import AutoTokenizer
5
 
6
- from state import ProgramState
7
- from text import get_text
8
-
9
- STATE = ProgramState(
10
- current_token=20,
11
- )
12
-
13
-
14
- tokenizer = AutoTokenizer.from_pretrained("gpt2")
15
- model = AutoModelForCausalLM.from_pretrained("gpt2")
16
- model.eval()
17
-
18
- all_tokens = tokenizer.encode(get_text())
19
-
20
-
21
- def handle_guess(text: str) -> str:
22
- """
23
- Retreives
24
- """
25
- STATE.current_token += 1
26
- decoded_str = tokenizer.decode(all_tokens[:STATE.current_token])
27
- return decoded_str, ""
28
-
29
-
30
- def get_model_predictions(input_text: str) -> torch.Tensor:
31
- """
32
- Returns the indices as a torch tensor of the top 3 predicted tokens.
33
- """
34
- inputs = tokenizer(input_text, return_tensors="pt")
35
-
36
- with torch.no_grad():
37
- logits = model(**inputs).logits
38
-
39
- last_token = logits[0, -1]
40
- top_3 = torch.topk(last_token, 3)
41
-
42
- return top_3
43
-
44
- def build_demo():
45
- with gr.Blocks() as demo:
46
- gr.Markdown("<h1><center>Can you beat a language model?</center></h1>")
47
-
48
- with gr.Row():
49
- prompt_text = gr.Markdown()
50
- with gr.Row():
51
- with gr.Column():
52
- guess = gr.Textbox(label="Guess!")
53
- guess_btn = gr.Button(value="Guess!")
54
- with gr.Column():
55
- lm_guess = gr.Textbox(label="LM guess")
56
-
57
- guess_btn.click(handle_guess, inputs=guess, outputs=[prompt_text, lm_guess], api_name="get_next_word")
58
- return demo
59
-
60
-
61
- def wip_sign():
62
- with gr.Blocks() as demo:
63
- gr.Markdown("<h1><center>Can you beat a language model?</center></h1>")
64
- with gr.Row():
65
- gr.Markdown("<h1><center>β›”πŸ‘·β€β™‚οΈ Work in progress, come back later </center></h1>")
66
-
67
- return demo
68
 
69
 
70
  def main():
71
- demo = wip_sign()
72
- # demo = build_demo()
73
  demo.launch(debug=True)
74
 
75
 
 
1
import logging

from src import interface

# DEBUG level while the app is a work in progress; dial down for release.
logging.basicConfig(level="DEBUG")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
 
7
def main():
    """Entry point: build the demo UI (currently the WIP placeholder) and serve it."""
    demo = interface.get_demo(wip=True)
    demo.launch(debug=True)
10
 
11
 
src/handler.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import torch
import logging

from src.state import STATE
from src.state import tokenizer
from src.state import model
from src.text import get_text

logger = logging.getLogger(__name__)

# Token ids of the full hidden text; the game reveals them one word at a time
# via STATE.current_word_index. Encoded once at import time.
all_tokens = tokenizer.encode(get_text())
12
+
13
+
14
def get_model_predictions(input_text: str, top_k: int = 3) -> list:
    """
    Return the ids of the ``top_k`` most likely next tokens.

    Args:
        input_text: Context string fed to the language model.
        top_k: Number of candidate token ids to return (default 3,
            matching the player's three attempts).

    Returns:
        A plain list of ints (token ids), highest probability first.
        (BUG FIX: the annotation previously claimed ``torch.Tensor``,
        but ``.indices.tolist()`` produces a list.)
    """
    inputs = tokenizer(input_text, return_tensors="pt")

    # Inference only — no gradient tracking needed.
    with torch.no_grad():
        logits = model(**inputs).logits

    # Logits for the position following the last input token.
    last_token = logits[0, -1]
    return torch.topk(last_token, top_k).indices.tolist()
26
+
27
+
28
def handle_guess(text: str) -> tuple:
    """
    Process one player guess and advance the shared game state.

    The guess is compared against the next hidden token both as-is and with
    a leading space (the tokenizer encodes "word" and " word" to different
    ids — the ". " + text trick extracts the whitespace-prefixed id).

    Args:
        text: The player's guessed word; empty input is a no-op refresh.

    Returns:
        ``(current_text, player_guesses, lm_guesses, remaining_attempts)``
        matching the Gradio outputs wired up in src/interface.py.
        (BUG FIX: previously returned the never-updated local ``lm_guesses``
        empty string instead of the model's guesses.)
    """
    current_tokens = all_tokens[:STATE.current_word_index]
    current_text = tokenizer.decode(current_tokens)

    # Derive attempts from the wrong guesses stored for the current word so
    # the count survives across calls. (BUG FIX: a local
    # `remaining_attempts = 3` reset on every call, making the
    # attempts-exhausted branch unreachable.)
    remaining_attempts = 3 - len(STATE.player_guesses)

    if not text:
        return (
            current_text,
            "\n".join(STATE.player_guesses),
            tokenizer.decode(STATE.lm_guesses),
            remaining_attempts,
        )

    next_token = all_tokens[STATE.current_word_index]
    predicted_token_start = tokenizer.encode(text, add_special_tokens=False)[0]
    predicted_token_whitespace = tokenizer.encode(". " + text, add_special_tokens=False)[1]
    logger.debug("Next token: '{}'".format(tokenizer.convert_ids_to_tokens([next_token])))
    logger.debug(tokenizer.convert_ids_to_tokens([predicted_token_start, predicted_token_whitespace]))

    guess_is_correct = next_token in (predicted_token_start, predicted_token_whitespace)

    if guess_is_correct or remaining_attempts <= 1:
        # Word solved (or attempts exhausted): reveal it and reset per-word state.
        STATE.current_word_index += 1
        current_tokens = all_tokens[:STATE.current_word_index]
        STATE.player_guesses = []
        STATE.lm_guesses = []
        remaining_attempts = 3
    else:
        STATE.player_guesses.append(tokenizer.decode([predicted_token_whitespace]))
        remaining_attempts -= 1
        # FIXME: unoptimized, computing all three every time
        STATE.lm_guesses = get_model_predictions(
            tokenizer.decode(current_tokens)
        )[:3 - remaining_attempts]

    logger.debug("lm_guesses: {}".format(tokenizer.decode(STATE.lm_guesses)))

    return (
        tokenizer.decode(current_tokens),
        "\n".join(STATE.player_guesses),
        tokenizer.decode(STATE.lm_guesses),
        remaining_attempts,
    )
src/interface.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from src.handler import handle_guess
3
+
4
+
5
def build_demo():
    """
    Build the full game UI.

    Layout: title, instructions, the revealed context, point counters,
    per-word guess panels, the guess input, and a "Next word" button.
    The guess button is wired to ``handle_guess``, whose return tuple
    maps onto (prompt_text, current_guesses, lm_guesses, remaining_attempts).
    """
    with gr.Blocks() as demo:
        with gr.Row():
            gr.Markdown("<h1><center>Can you beat a language model?</center></h1>")
        with gr.Row():
            gr.Markdown(
                "Can you beat language models at their own game?\n"
                # BUG FIX: user-facing typo "laungage" corrected.
                "In this game you're pitted against a language model in the task of, you guessed it, language modelling.\n"
                "Your task is to predict the next word given the previous sequence. You will get 3 attempts to guess.\n"
                "The one with the fewest guesses for a given word gets a point."
            )
        with gr.Row():
            prompt_text = gr.Textbox(label="Context", interactive=False)
        with gr.Row():
            with gr.Column():
                player_points = gr.Number(label="your points", interactive=False)
            with gr.Column():
                lm_points = gr.Number(label="LM points", interactive=False)
        with gr.Row():
            with gr.Column():
                remaining_attempts = gr.Number(label="Remaining attempts")
                current_guesses = gr.Textbox(label="Your guesses")
            with gr.Column():
                lm_guesses = gr.Textbox(label="LM guesses")

        with gr.Row():
            with gr.Column():
                guess = gr.Textbox(label="")
                guess_button = gr.Button(value="Guess!")

        with gr.Row():
            # NOTE(review): not wired to a handler yet — WIP.
            next_word = gr.Button(value="Next word")

        guess_button.click(
            handle_guess,
            inputs=guess,
            outputs=[
                prompt_text,
                current_guesses,
                lm_guesses,
                remaining_attempts,
            ],
        )

    return demo
50
+
51
+
52
def wip_sign():
    """Minimal placeholder page shown while the game is under construction."""
    heading = "<h1><center>Can you beat a language model?</center></h1>"
    notice = "<h1><center>β›”πŸ‘·β€β™‚οΈ Work in progress, come back later </center></h1>"
    with gr.Blocks() as page:
        gr.Markdown(heading)
        with gr.Row():
            gr.Markdown(notice)
    return page
59
+
60
+
61
def get_demo(wip=False):
    """Return the WIP placeholder when ``wip`` is truthy, otherwise the real game UI."""
    return wip_sign() if wip else build_demo()
66
+
67
+
68
+
src/state.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from dataclasses import dataclass

from transformers import AutoTokenizer
from transformers import AutoModelForCausalLM
# BUG FIX: removed a second, duplicate `from dataclasses import dataclass`.


@dataclass
class ProgramState:
    """Mutable global game state shared across Gradio handler calls."""

    current_word_index: int  # index into the hidden token list of the word being guessed
    player_guesses: list     # decoded wrong guesses for the current word
    player_points: int
    lm_guesses: list         # model's candidate token ids for the current word
    lm_points: int


# Module-level singleton mutated in place by src/handler.py.
STATE = ProgramState(
    current_word_index=20,  # start 20 tokens in so the player has context
    player_guesses=[],
    lm_guesses=[],
    player_points=0,
    lm_points=0,
)

# Shared GPT-2 model/tokenizer, loaded once at import time.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
model.eval()  # inference only; disables dropout
text.py β†’ src/text.py RENAMED
File without changes
state.py DELETED
@@ -1,10 +0,0 @@
1
- from dataclasses import dataclass
2
-
3
-
4
- @dataclass
5
- class ProgramState:
6
- current_token: int
7
- # full_text: str
8
-
9
- def get_text(self):
10
- pass