from unittest import TestCase

from hypothesis import given
from hypothesis import strategies as st
from transformers import AutoTokenizer

from trlx.pipeline.offline_pipeline import DialogMessage, tokenize_dialogue


class TestTokenizeDialog(TestCase):
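    # Property-based tests (via hypothesis) for tokenize_dialogue using the GPT-2 tokenizer,
    # covering single- and multi-turn dialogues and left/right truncation.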
    def setUp(self):
        self.tokenizer = AutoTokenizer.from_pretrained("gpt2")

    def test_tokenize_dialogue_truncation_basic(self):
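        # With left truncation and max_length=2, only a BOS prompt token and an EOS output token should remain.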
        dialogue = ["this will be truncated", "."]
        self.tokenizer.truncation_side = "left"

        dialog = tokenize_dialogue(dialogue, self.tokenizer, max_length=2)

        assert len(dialog) == 2
        user_dm, bot_dm = dialog
        assert len(user_dm.tokens) == 1
        assert len(bot_dm.tokens) == 1
        assert user_dm == DialogMessage(is_output=False, tokens=(self.tokenizer.bos_token_id,))
        assert bot_dm == DialogMessage(is_output=True, tokens=(self.tokenizer.eos_token_id,))

    @given(st.lists(st.text(), max_size=32))
    def test_tokenize_dialogue_single_turn(self, response_words):
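        # A bare response string should become a BOS-only prompt message plus an output message
        # holding the tokenized response terminated by EOS.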
        response = " ".join(response_words)  # space-separate to make it multiple tokens
        tokenized_response = tuple(self.tokenizer(response, add_special_tokens=False).input_ids)
        tokenized_response = tokenized_response + (self.tokenizer.eos_token_id,)
        dialog = tokenize_dialogue(response, self.tokenizer)

        assert len(dialog) == 2
        user_dm, bot_dm = dialog

        assert user_dm == DialogMessage(is_output=False, tokens=(self.tokenizer.bos_token_id,))
        assert bot_dm == DialogMessage(is_output=True, tokens=tokenized_response)

    @given(st.lists(st.text(), max_size=32), st.integers(min_value=2, max_value=16))
    def test_tokenize_dialogue_single_turn_truncation_right(self, response_words, max_length):
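        # With right-side truncation, the output message should keep only the first max_length - 1
        # response tokens (the BOS prompt message takes the remaining slot).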
        response = " ".join(response_words)  # space-separate to make it multiple tokens
        self.tokenizer.truncation_side = "right"
        tokenized_response = tuple(self.tokenizer(response, add_special_tokens=False).input_ids)
        tokenized_response = tokenized_response + (self.tokenizer.eos_token_id,)
        dialog = tokenize_dialogue(response, self.tokenizer, max_length=max_length)

        assert len(dialog) == 2
        user_dm, bot_dm = dialog

        assert user_dm == DialogMessage(is_output=False, tokens=(self.tokenizer.bos_token_id,))
        assert bot_dm == DialogMessage(is_output=True, tokens=tokenized_response[: max_length - 1])

        all_tokens = sum((dm.tokens for dm in dialog), ())
        assert len(all_tokens) <= max_length

    @given(st.lists(st.text(), max_size=32), st.integers(min_value=2, max_value=16))
    def test_tokenize_dialogue_single_turn_truncation_left(self, response_words, max_length):
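        # With left-side truncation, the output message should keep the last max_length - 1 tokens
        # of the response (including the trailing EOS) whenever it exceeds the budget.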
        response = " ".join(response_words)  # space-separate to make it multiple tokens
        self.tokenizer.truncation_side = "left"
        tokenized_response = tuple(self.tokenizer(response, add_special_tokens=False).input_ids)
        tokenized_response += (self.tokenizer.eos_token_id,)
        dialog = tokenize_dialogue(response, self.tokenizer, max_length=max_length)

        # whether or not truncation has happened, user BOS prompt should be present
        assert len(dialog) == 2
        user_dm, bot_dm = dialog
        assert user_dm == DialogMessage(is_output=False, tokens=(self.tokenizer.bos_token_id,))

        if len(tokenized_response) < max_length:
            assert bot_dm == DialogMessage(is_output=True, tokens=tokenized_response)
        else:
            assert bot_dm == DialogMessage(is_output=True, tokens=tokenized_response[-max_length + 1 :])

        all_tokens = sum((dm.tokens for dm in dialog), ())
        assert len(all_tokens) <= max_length

    @given(st.lists(st.tuples(st.text(), st.text()), min_size=1, max_size=32))
    def test_tokenize_dialogue_multi_turn(self, user_response_pairs):
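        # Alternating user/assistant turns should map onto DialogMessages with alternating is_output
        # flags; empty turns are dropped, and if the first remaining message is an output, an
        # EOS-only prompt message is expected to be prepended.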
        convo = [[" ".join(user_words), " ".join(response_words)] for user_words, response_words in user_response_pairs]
        flat_convo = sum(convo, [])
        tokenized_flat_convo = tuple(
            tuple(self.tokenizer(turn, add_special_tokens=False).input_ids) for turn in flat_convo
        )
        tokenized_flat_convo = (*tokenized_flat_convo[:-1], (*tokenized_flat_convo[-1], self.tokenizer.eos_token_id))
        dialog = tokenize_dialogue(flat_convo, self.tokenizer)

        dm_convo = [DialogMessage(is_output=i % 2 == 1, tokens=tokens) for i, tokens in enumerate(tokenized_flat_convo)]
        nonempty_dm_convo = [dm for dm in dm_convo if dm.tokens]
        if nonempty_dm_convo[0].is_output:
            nonempty_dm_convo.insert(0, DialogMessage(is_output=False, tokens=(self.tokenizer.eos_token_id,)))

        assert dialog == nonempty_dm_convo

    @given(st.lists(st.tuples(st.text(), st.text()), min_size=1, max_size=32), st.integers(min_value=2, max_value=16))
    def test_tokenize_dialogue_multi_turn_truncation_right(self, user_response_pairs, max_length):
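        # With right-side truncation, the flattened tokens should equal the first max_length tokens
        # of the concatenated turns; a prepended EOS prompt token counts toward the budget.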
        convo = [[" ".join(user_words), " ".join(response_words)] for user_words, response_words in user_response_pairs]
        flat_convo = sum(convo, [])
        self.tokenizer.truncation_side = "right"
        tokenized_flat_convo = tuple(
            tuple(self.tokenizer(turn, add_special_tokens=False).input_ids) for turn in flat_convo
        )
        tokenized_flat_convo = (*tokenized_flat_convo[:-1], (*tokenized_flat_convo[-1], self.tokenizer.eos_token_id))
        dialog = tokenize_dialogue(flat_convo, self.tokenizer, max_length=max_length)

        all_tokens = sum((dm.tokens for dm in dialog), ())
        should_be_tokens = sum(tokenized_flat_convo, ())[:max_length]
        if dialog[0] == DialogMessage(is_output=False, tokens=(self.tokenizer.eos_token_id,)):
            should_be_tokens = (self.tokenizer.eos_token_id, *should_be_tokens[: max_length - 1])

        assert all_tokens == should_be_tokens
        assert len(all_tokens) <= max_length

    @given(st.lists(st.tuples(st.text(), st.text()), min_size=1, max_size=32), st.integers(min_value=2, max_value=16))
    def test_tokenize_dialogue_multi_turn_truncation_left(self, user_response_pairs, max_length):
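        # With left-side truncation, the flattened tokens should equal the last max_length tokens
        # of the concatenated turns, again accounting for a possible prepended EOS prompt token.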
        convo = [[" ".join(user_words), " ".join(response_words)] for user_words, response_words in user_response_pairs]
        flat_convo = sum(convo, [])
        self.tokenizer.truncation_side = "left"
        tokenized_flat_convo = tuple(
            tuple(self.tokenizer(turn, add_special_tokens=False).input_ids) for turn in flat_convo
        )
        tokenized_flat_convo = (*tokenized_flat_convo[:-1], (*tokenized_flat_convo[-1], self.tokenizer.eos_token_id))
        dialog = tokenize_dialogue(flat_convo, self.tokenizer, max_length=max_length)

        all_tokens = sum((dm.tokens for dm in dialog), ())
        should_be_tokens = sum(tokenized_flat_convo, ())[-max_length:]
        if dialog[0] == DialogMessage(is_output=False, tokens=(self.tokenizer.eos_token_id,)):
            should_be_tokens = (self.tokenizer.eos_token_id, *should_be_tokens[-max_length + 1 :])

        assert all_tokens == should_be_tokens
        assert len(all_tokens) <= max_length