File size: 8,456 Bytes
ed0dca2
a40632d
010d9e4
bdf3e70
38d7f73
 
2f1a468
 
bfbbbea
 
8ec03fe
1c03d71
bdf3e70
30da7cc
8ec03fe
9e53c43
b1cf10f
30da7cc
b1cf10f
8ec03fe
a40632d
aeb447f
a52308c
6664e37
aeb447f
 
 
 
 
 
6664e37
aeb447f
 
 
 
 
 
6664e37
aeb447f
 
 
 
 
058f0e8
2f1a468
 
 
 
 
 
 
 
 
 
 
 
30da7cc
2f1a468
 
 
 
 
 
 
 
42cd34a
bdf3e70
2f1a468
 
ed0dca2
52589e7
7b89069
 
17dfda2
 
7b89069
cfbd01a
 
e590182
d8b753d
e590182
 
 
7b89069
 
 
 
 
e590182
b192cf6
e590182
b192cf6
e590182
b192cf6
e590182
 
 
 
 
 
 
 
7b89069
 
 
 
 
 
 
44d732c
7b89069
 
 
e590182
7b89069
e590182
 
7b89069
 
 
 
 
 
 
 
 
e590182
 
 
 
d8b753d
7b89069
 
02fc014
d8b753d
7b89069
 
e590182
 
 
7b89069
e590182
 
7b89069
 
 
 
 
 
 
 
 
 
e590182
2f1a468
7b89069
 
2f1a468
7b89069
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2f1a468
7b89069
2f1a468
17dfda2
 
52589e7
 
 
 
17dfda2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
##################################### Imports ######################################
# Generic imports
import spaces
import gradio as gr
import json

# Specialized imports
#from utilities.modeling import modeling

# server import
from server import submit_weights #, train_model, submit_weights


# Module imports
from utilities.setup import get_files
from utilities.templates import prompt_template

########################### Global objects and functions ###########################

# Application configuration, loaded once when this module is imported.
# main() reads conf['model']['choices'] plus the 'general', 'peft' and 'sft'
# sub-sections under conf['model'] to populate widget defaults.
# NOTE(review): presumably parsed from a JSON file (per the helper's name) —
# confirm in utilities.setup.get_files.
conf = get_files.json_cfg()

class update_visibility:
    """Event handlers that show or hide the dataset-input widgets.

    Each handler is wired to ``dataset_choice.change`` in ``main`` and receives
    the currently selected value of the dataset-source radio. The handlers hold
    no state and are referenced directly on the class
    (``update_visibility.textbox_vis``), so they are static methods.

    They return ``gr.update(visible=...)`` rather than a freshly constructed
    component. That is valid for any target component type (the Hub dataset
    field is a ``gr.Textbox``, not a ``gr.Dropdown``) and avoids building a
    throwaway component on every event.
    """

    # Radio values; must match the choices given to the gr.Radio in main().
    HUB_DATASET = "Hugging Face Hub Dataset"
    UPLOAD_OWN = "Upload Your Own"

    @staticmethod
    def textbox_vis(radio):
        """Show the Hub dataset-name input only when the Hub option is selected.

        Args:
            radio: Currently selected value of the dataset-source radio.

        Returns:
            A Gradio update that sets ``visible`` on the target component.
        """
        return gr.update(visible=radio == update_visibility.HUB_DATASET)

    @staticmethod
    def textbox_button_vis(radio):
        """Show the Hub dataset load button only when the Hub option is selected.

        Args:
            radio: Currently selected value of the dataset-source radio.

        Returns:
            A Gradio update that sets ``visible`` on the target component.
        """
        return gr.update(visible=radio == update_visibility.HUB_DATASET)

    @staticmethod
    def upload_vis(radio):
        """Show the file-upload button only when "Upload Your Own" is selected.

        Args:
            radio: Currently selected value of the dataset-source radio.

        Returns:
            A Gradio update that sets ``visible`` on the target component.
        """
        return gr.update(visible=radio == update_visibility.UPLOAD_OWN)


@spaces.GPU
def train(model_name, inject_prompt, dataset_predefined, peft, sft,
          max_seq_length, random_seed, num_epochs, max_steps, data_field,
          repository, model_out_name):
    """Run fine-tuning for the selected model (currently a placeholder).

    Receives every value wired from the "Train Model" button in ``main``, but
    only ``model_name`` and ``inject_prompt`` are used: the real training call
    is still disabled, so the function just echoes those two values back as a
    status string for the output textbox.

    Returns:
        str: A short message naming the model and the prompt template.
    """
    # TODO: build the trainer and run it, e.g.
    #   trainer = modeling(model_name, max_seq_length, random_seed,
    #                      peft, sft, dataset, data_field)
    #   trainer_stats = trainer.train()
    # then return the training outputs instead of this placeholder message.
    status = "Hello!! Using model: {} with template: {}".format(model_name,
                                                                inject_prompt)
    return status



##################################### App UI #######################################



def main():
    """Build the Gradio fine-tuning UI and launch it.

    Tabs:
      * About       - renders README.md.
      * Basic Setup - model choice, output repository/name and HF token.
      * Upload Data - pick a Hub dataset or upload a file, preview it, and edit
                      the prompt template.
      * Train Model - training parameters, advanced PEFT/SFT JSON, and a button
                      that calls ``train`` and shows its result.

    Widget defaults come from the module-level ``conf``. ``demo.launch()`` is
    called after the ``gr.Blocks`` context has closed, so the layout is
    finalised before the server starts.
    """
    with gr.Blocks() as demo:

        with gr.Tabs():
            with gr.TabItem("About"):
                # About page: render the project README.
                gr.Markdown(get_files.load_markdown_file("README.md"))

            with gr.TabItem("Basic Setup"):
                gr.Markdown("# Select Model and Input details")
                # Select Model
                modelnames = conf['model']['choices']
                model_name = gr.Dropdown(label="Supported Models",
                                         choices=modelnames,
                                         value=modelnames[0])
                # Generic model parameters shared by several tabs.
                general = conf['model']['general']
                repository = gr.Textbox(label="Your User Name",
                                        value=general["repository"])
                model_out_name = gr.Textbox(label="Your Model Output Name",
                                            value=general["model_name"])
                # NOTE(review): hf_token is collected but not yet passed to
                # train(); wire it in once pushing to the Hub is implemented.
                hf_token = gr.Textbox(label="Your Huggingface Token",
                                      type='password',
                                      value='')

            with gr.TabItem("Upload Data"):
                gr.Markdown("# Dataset Selection and Upload")

                # Toggle between a Hub dataset name and a local file upload.
                dataset_choice = gr.Radio(label="Choose Dataset",
                                          choices=["Hugging Face Hub Dataset", "Upload Your Own"],
                                          value="Hugging Face Hub Dataset")
                dataset_predefined = gr.Textbox(label="Hugging Face Hub Training Dataset",
                                                value='yahma/alpaca-cleaned',
                                                visible=True)
                # Loads the Hub dataset named in the textbox above.
                dataset_predefined_load = gr.Button("Load Hub Dataset")

                dataset_uploaded_load = gr.UploadButton(label="Upload Dataset (.csv, .jsonl, or .txt)",
                                                        file_types=[".csv", ".jsonl", ".txt"],
                                                        visible=False)
                # Preview/status output showing whether the dataset load succeeded.
                data_snippet = gr.Markdown()

                # Visibility toggles driven by the radio selection.
                dataset_choice.change(update_visibility.textbox_vis,
                                      dataset_choice,
                                      dataset_predefined)
                dataset_choice.change(update_visibility.upload_vis,
                                      dataset_choice,
                                      dataset_uploaded_load)
                dataset_choice.change(update_visibility.textbox_button_vis,
                                      dataset_choice,
                                      dataset_predefined_load)
                # Prompt template
                inject_prompt = gr.Textbox(label="Prompt Template",
                                           value=prompt_template())
                # Dataset loaders
                dataset_predefined_load.click(fn=get_files.predefined_dataset,
                                              inputs=dataset_predefined,
                                              outputs=data_snippet)
                # Use `upload`, not `click`: `click` fires when the button is
                # pressed, before the user has picked a file, so the handler
                # would receive no file (or a stale one).
                dataset_uploaded_load.upload(fn=get_files.uploaded_dataset,
                                             inputs=dataset_uploaded_load,
                                             outputs=data_snippet)

            with gr.TabItem("Train Model"):
                ##### Model Parameter Inputs #####
                gr.Markdown("# Model Parameter Selection")

                # Parameters (Textbox values reach train() as strings).
                data_field = gr.Textbox(label="Dataset Training Field Name",
                                        value=general["dataset_text_field"])
                max_seq_length = gr.Textbox(label="Maximum sequence length",
                                            value=general["max_seq_length"])
                random_seed = gr.Textbox(label="Seed",
                                         value=general["seed"])
                num_epochs = gr.Textbox(label="Training Epochs",
                                        value=general["num_train_epochs"])
                max_steps = gr.Textbox(label="Maximum steps",
                                       value=general["max_steps"])

                # Hyperparameters: editable as JSON, hidden in an accordion.
                with gr.Accordion("Advanced Tuning", open=False):
                    peft = gr.Textbox(label="PEFT Parameters (json)",
                                      value=json.dumps(dict(conf['model']['peft']), indent=4))
                    sft = gr.Textbox(label="SFT Parameters (json)",
                                     value=json.dumps(dict(conf['model']['sft']), indent=4))

                ##### Execution #####
                tune_btn = gr.Button("Start Fine Tuning")
                gr.Markdown("### Model Progress")
                # Text output (for now)
                output = gr.Textbox(label="Output")

                # NOTE(review): dataset_predefined is passed even when the
                # user chose "Upload Your Own"; confirm how train() should
                # pick the dataset source once training is implemented.
                tune_btn.click(fn=train,
                               inputs=[model_name,
                                       inject_prompt,
                                       dataset_predefined,
                                       peft,
                                       sft,
                                       max_seq_length,
                                       random_seed,
                                       num_epochs,
                                       max_steps,
                                       data_field,
                                       repository,
                                       model_out_name
                                      ],
                               outputs=output)
                # TODO: stop button
                # TODO: submit button

    # Launch after the Blocks context has closed so the layout is finalised.
    demo.launch()

##################################### Launch #######################################

# Script entry point: build and launch the Gradio UI only when run directly,
# not when this module is imported.
if __name__ == "__main__":
    main()