File size: 13,813 Bytes
d754e91
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
import gradio as gr
import time

import torch
import transformers
from transformers import GenerationConfig

from ..globals import Global
from ..models import get_model_with_lora, get_tokenizer, get_device
from ..utils.data import get_available_template_names
from ..utils.prompter import Prompter
from ..utils.callbacks import Iteratorize, Stream

# Device that `inference` moves the tokenized prompt's input ids onto.
# Resolved once, when this module is imported, via get_device().
device = get_device()


def inference(
    lora_model_name,
    prompt_template,
    variable_0, variable_1, variable_2, variable_3,
    variable_4, variable_5, variable_6, variable_7,
    temperature=0.1,
    top_p=0.75,
    top_k=40,
    num_beams=4,
    repetition_penalty=1.2,
    max_new_tokens=128,
    stream_output=False,
    **kwargs,
):
    """Build a prompt from the template variables and yield generated text.

    This is a generator: it always yields strings, whether or not streaming
    is requested.

    Args:
        lora_model_name: Passed to ``get_model_with_lora`` to obtain the model.
        prompt_template: Template name handed to ``Prompter``.
        variable_0 .. variable_7: The eight template variable values, passed
            to ``Prompter.generate_prompt`` as a list in this order.
        temperature, top_p, top_k, num_beams, repetition_penalty: Forwarded
            to ``GenerationConfig``.
        max_new_tokens: Forwarded to ``model.generate``.
        stream_output: When truthy, yield the response repeatedly as
            generation progresses; otherwise yield it once at the end.
        **kwargs: Extra keyword arguments forwarded to ``GenerationConfig``.

    Yields:
        In UI dev mode, a single message echoing the prompt (no model is
        loaded). Otherwise the text returned by ``Prompter.get_response``
        for the decoded output.

    NOTE(review): ``do_sample`` is not set in this function, so whether
    temperature / top_p / top_k actually influence generation depends on
    GenerationConfig defaults or on a ``do_sample`` supplied via ``kwargs``
    -- confirm against the transformers version in use.
    """
    prompter = Prompter(prompt_template)
    prompt = prompter.generate_prompt([
        variable_0, variable_1, variable_2, variable_3,
        variable_4, variable_5, variable_6, variable_7,
    ])

    if Global.ui_dev_mode:
        # Dev mode: skip model loading entirely and just echo the prompt.
        dev_message = f"Currently in UI dev mode, not running actual inference.\n\nYour prompt is:\n\n{prompt}"
        print(dev_message)
        time.sleep(1)
        yield dev_message
        return

    model = get_model_with_lora(lora_model_name)
    tokenizer = get_tokenizer()

    input_ids = tokenizer(prompt, return_tensors="pt")["input_ids"].to(device)
    generation_config = GenerationConfig(
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        num_beams=num_beams,
        **kwargs,
    )

    # Shared by both the streaming and the one-shot code paths.
    generate_params = dict(
        input_ids=input_ids,
        generation_config=generation_config,
        return_dict_in_generate=True,
        output_scores=True,
        max_new_tokens=max_new_tokens,
    )

    if not stream_output:
        # One-shot: run generation to completion, then yield a single result.
        with torch.no_grad():
            generation_output = model.generate(**generate_params)
        full_text = tokenizer.decode(generation_output.sequences[0])
        yield prompter.get_response(full_text)
        return

    # Stream the reply 1 token at a time.
    # This is based on the trick of using 'stopping_criteria' to create an iterator,
    # from https://github.com/oobabooga/text-generation-webui/blob/ad37f396fc8bcbab90e11ecf17c56c97bfbd4a9c/modules/text_generation.py#L216-L243.

    def run_generation(callback=None, **gen_kwargs):
        # Append a Stream criterion so `callback` is invoked during
        # generation; an existing criteria list is reused if one was given.
        criteria = gen_kwargs.setdefault(
            "stopping_criteria", transformers.StoppingCriteriaList()
        )
        criteria.append(Stream(callback_func=callback))
        with torch.no_grad():
            model.generate(**gen_kwargs)

    # Pass a copy so the streaming helper never sees the shared dict itself.
    with Iteratorize(
        run_generation, dict(generate_params), callback=None
    ) as token_stream:
        for token_ids in token_stream:
            partial_text = tokenizer.decode(token_ids)

            # Stop before emitting anything once the EOS token shows up.
            if token_ids[-1] in [tokenizer.eos_token_id]:
                break

            yield prompter.get_response(partial_text)


def reload_selections(current_lora_model, current_prompt_template):
    """Recompute dropdown choices for the LoRA model and prompt template.

    Args:
        current_lora_model: Currently selected LoRA model; kept when truthy,
            otherwise replaced by the first available model.
        current_prompt_template: Currently selected template; kept only when
            it is truthy and present in the refreshed choices, otherwise
            replaced by the first available choice.

    Returns:
        A 2-tuple of ``gr.Dropdown.update(...)`` results: first for the LoRA
        model dropdown, then for the prompt template dropdown.
    """
    # "None" is always offered as the last template choice, so this list is
    # never empty.
    template_choices = get_available_template_names() + ["None"]
    if (current_prompt_template not in template_choices
            or not current_prompt_template):
        current_prompt_template = template_choices[0]

    lora_choices = ["tloen/alpaca-lora-7b"]
    if not current_lora_model:
        current_lora_model = lora_choices[0]

    lora_update = gr.Dropdown.update(
        choices=lora_choices, value=current_lora_model)
    template_update = gr.Dropdown.update(
        choices=template_choices, value=current_prompt_template)
    return lora_update, template_update


def handle_prompt_template_change(prompt_template):
    """Return textbox updates matching the variables of ``prompt_template``.

    Each variable name reported by ``Prompter.get_variable_names`` becomes a
    visible textbox labelled with a humanized name (``some_name`` ->
    ``Some Name``). The list is then padded with hidden "Not Used" updates
    until it holds at least 8 entries.

    NOTE(review): a template with more than 8 variables yields more than 8
    updates; no truncation happens here.
    """
    def humanize(raw_name):
        # "foo_bar" -> "Foo Bar"
        return ' '.join(part.capitalize() for part in raw_name.split('_'))

    variable_names = Prompter(prompt_template).get_variable_names()
    labels = [humanize(raw_name) for raw_name in variable_names]

    updates = [gr.Textbox.update(label=label, visible=True)
               for label in labels]
    # range() of a non-positive count is empty, so nothing is added when
    # there are already 8 or more updates.
    updates.extend(
        gr.Textbox.update(label="Not Used", visible=False)
        for _ in range(8 - len(updates))
    )
    return updates


def inference_ui():
    """Build the inference tab: inputs, generation controls, output, wiring.

    Creates a ``gr.Blocks`` holding:
      * a LoRA model dropdown, a prompt template dropdown and a reload button;
      * eight template-variable textboxes (only the first is visible at
        creation time);
      * Generate / Stop buttons;
      * sliders for temperature, top-p, top-k, beams, repetition penalty and
        max new tokens, plus a "Stream Output" checkbox;
      * an output textbox.

    It then connects the reload button to ``reload_selections``, template
    changes to ``handle_prompt_template_change``, the Generate button to
    ``inference``, and the Stop button to cancelling the generate event.
    Finally it registers a page-load JavaScript snippet that triggers the
    reload button, attaches tooltips, and toggles the Generate/Stop buttons.

    The function contains no ``return`` statement.
    """
    with gr.Blocks() as inference_ui_blocks:
        # Top row: model / template selection and the reload button.
        with gr.Row():
            lora_model = gr.Dropdown(
                label="LoRA Model",
                elem_id="inference_lora_model",
                value="tloen/alpaca-lora-7b",
                allow_custom_value=True,
            )
            prompt_template = gr.Dropdown(
                label="Prompt Template",
                elem_id="inference_prompt_template",
            )
            reload_selections_button = gr.Button(
                "Reload",
                elem_id="inference_reload_selections_button"
            )
            reload_selections_button.style(
                full_width=False,
                size="sm")
        # Main row: left column holds inputs and controls, right column the output.
        with gr.Row():
            with gr.Column():
                with gr.Column():
                    # Eight variable slots; labels and visibility are updated
                    # by handle_prompt_template_change when the template changes.
                    variable_0 = gr.Textbox(lines=2, label="Prompt")
                    variable_1 = gr.Textbox(lines=2, label="", visible=False)
                    variable_2 = gr.Textbox(lines=2, label="", visible=False)
                    variable_3 = gr.Textbox(lines=2, label="", visible=False)
                    variable_4 = gr.Textbox(lines=2, label="", visible=False)
                    variable_5 = gr.Textbox(lines=2, label="", visible=False)
                    variable_6 = gr.Textbox(lines=2, label="", visible=False)
                    variable_7 = gr.Textbox(lines=2, label="", visible=False)

                with gr.Column():
                    with gr.Row():
                        # The load-time JS below shows only one of these two
                        # buttons at a time.
                        generate_btn = gr.Button(
                            "Generate", variant="primary", label="Generate", elem_id="inference_generate_btn",
                        )
                        stop_btn = gr.Button(
                            "Stop", variant="stop", label="Stop Iterating", elem_id="inference_stop_btn")

                # Generation parameters; each is passed to `inference` in the
                # generate_btn.click inputs list further down.
                with gr.Column():
                    temperature = gr.Slider(
                        minimum=0.01, maximum=1.99, value=0.1, step=0.01,
                        label="Temperature",
                        elem_id="inference_temperature"
                    )

                    top_p = gr.Slider(
                        minimum=0, maximum=1, value=0.75, step=0.01,
                        label="Top P",
                        elem_id="inference_top_p"
                    )

                    top_k = gr.Slider(
                        minimum=0, maximum=200, value=40, step=1,
                        label="Top K",
                        elem_id="inference_top_k"
                    )

                    # Slider default is 1, whereas `inference` declares
                    # num_beams=4 as its own default.
                    num_beams = gr.Slider(
                        minimum=1, maximum=4, value=1, step=1,
                        label="Beams",
                        elem_id="inference_beams"
                    )

                    repetition_penalty = gr.Slider(
                        minimum=0, maximum=2.5, value=1.2, step=0.01,
                        label="Repetition Penalty",
                        elem_id="inference_repetition_penalty"
                    )

                    max_new_tokens = gr.Slider(
                        minimum=0, maximum=4096, value=128, step=1,
                        label="Max New Tokens",
                        elem_id="inference_max_new_tokens"
                    )

                    stream_output = gr.Checkbox(
                        label="Stream Output",
                        elem_id="inference_stream_output",
                        value=True
                    )
            with gr.Column():
                inference_output = gr.Textbox(
                    lines=12, label="Output Text", elem_id="inference_output")

        # Reload refreshes the choices (and possibly the values) of both dropdowns.
        reload_selections_button.click(
            reload_selections,
            inputs=[lora_model, prompt_template],
            outputs=[lora_model, prompt_template],
        )

        # Selecting a template relabels / shows / hides the eight variable textboxes.
        prompt_template.change(fn=handle_prompt_template_change, inputs=[prompt_template], outputs=[
                               variable_0, variable_1, variable_2, variable_3, variable_4, variable_5, variable_6, variable_7])

        # The inputs list follows the positional parameter order of `inference`.
        generate_event = generate_btn.click(
            fn=inference,
            inputs=[
                lora_model,
                prompt_template,
                variable_0, variable_1, variable_2, variable_3,
                variable_4, variable_5, variable_6, variable_7,
                temperature,
                top_p,
                top_k,
                num_beams,
                repetition_penalty,
                max_new_tokens,
                stream_output,
            ],
            outputs=inference_output
        )
        # Stop runs no Python function; it only cancels the generate event.
        stop_btn.click(fn=None, inputs=None, outputs=None,
                       cancels=[generate_event])

    # Registered after the `with gr.Blocks()` block above has closed.
    # The script (1) clicks the reload button after 100 ms and writes the
    # default model name into the dropdown's input element, (2) attaches
    # tippy() tooltips to the controls by element id, and (3) clicks the stop
    # button, then after 500 ms watches the class attribute of
    # "#inference_output > .wrap" to show Generate (and hide Stop) while that
    # element has the "hide" class, and the reverse otherwise.
    # NOTE(review): assumes a global `tippy` function is available on the
    # page -- it is not loaded anywhere in this function; confirm.
    inference_ui_blocks.load(_js="""
    function inference_ui_blocks_js() {
      // Auto load options
      setTimeout(function () {
        document.getElementById("inference_reload_selections_button").click();

        // Workaround default value not shown.
        document.querySelector('#inference_lora_model input').value = "tloen/alpaca-lora-7b";
      }, 100);

      // Add tooltips
      setTimeout(function () {

        tippy("#inference_prompt_template", {
         placement: 'bottom-start',
         delay: [500, 0],
         content: 'Templates are loaded from the "templates" folder of your data directory. Be sure to select the template that matches your selected LoRA model to get the best results.',
        });

        tippy("#inference_temperature", {
         placement: 'right',
         delay: [500, 0],
         content: 'Controls randomness: Lowering results in less random completions. Higher values (e.g., 1.0) make the model generate more diverse and random outputs. As the temperature approaches zero, the model will become deterministic and repetitive.',
        });

        tippy("#inference_top_p", {
         placement: 'right',
         delay: [500, 0],
         content: 'Controls diversity via nucleus sampling: only the tokens whose cumulative probability exceeds "top_p" are considered. 0.5 means half of all likelihood-weighted options are considered.',
        });

        tippy("#inference_top_k", {
         placement: 'right',
         delay: [500, 0],
         content: 'Controls diversity of the generated text by only considering the "top_k" tokens with the highest probabilities. This method can lead to more focused and coherent outputs by reducing the impact of low probability tokens.',
        });

        tippy("#inference_beams", {
         placement: 'right',
         delay: [500, 0],
         content: 'Number of candidate sequences explored in parallel during text generation using beam search. A higher value increases the chances of finding high-quality, coherent output, but may slow down the generation process.',
        });

        tippy("#inference_repetition_penalty", {
         placement: 'right',
         delay: [500, 0],
         content: 'Applies a penalty to the probability of tokens that have already been generated, discouraging the model from repeating the same words or phrases. The penalty is applied by dividing the token probability by a factor based on the number of times the token has appeared in the generated text.',
        });

        tippy("#inference_max_new_tokens", {
         placement: 'right',
         delay: [500, 0],
         content: 'Limits the maximum number of tokens generated in a single iteration.',
        });

        tippy("#inference_stream_output", {
         placement: 'right',
         delay: [500, 0],
         content: 'When enabled, generated text will be displayed in real-time as it is being produced by the model, allowing you to observe the text generation process as it unfolds.',
        });

      }, 100);

      // Show/hide generate and save button base on the state.
      setTimeout(function () {
        // Make the '#inference_output > .wrap' element appear
        document.getElementById("inference_stop_btn").click();

        setTimeout(function () {
          const output_wrap_element = document.querySelector(
            "#inference_output > .wrap"
          );
          function handle_output_wrap_element_class_change() {
            if (Array.from(output_wrap_element.classList).includes("hide")) {
              document.getElementById("inference_generate_btn").style.display =
                "block";
              document.getElementById("inference_stop_btn").style.display = "none";
            } else {
              document.getElementById("inference_generate_btn").style.display =
                "none";
              document.getElementById("inference_stop_btn").style.display = "block";
            }
          }
          new MutationObserver(function (mutationsList, observer) {
            handle_output_wrap_element_class_change();
          }).observe(output_wrap_element, {
            attributes: true,
            attributeFilter: ["class"],
          });
          handle_output_wrap_element_class_change();
        }, 500);
      }, 0);
    }
    """)