Spaces:
Sleeping
Sleeping
File size: 8,456 Bytes
ed0dca2 a40632d 010d9e4 bdf3e70 38d7f73 2f1a468 bfbbbea 8ec03fe 1c03d71 bdf3e70 30da7cc 8ec03fe 9e53c43 b1cf10f 30da7cc b1cf10f 8ec03fe a40632d aeb447f a52308c 6664e37 aeb447f 6664e37 aeb447f 6664e37 aeb447f 058f0e8 2f1a468 30da7cc 2f1a468 42cd34a bdf3e70 2f1a468 ed0dca2 52589e7 7b89069 17dfda2 7b89069 cfbd01a e590182 d8b753d e590182 7b89069 e590182 b192cf6 e590182 b192cf6 e590182 b192cf6 e590182 7b89069 44d732c 7b89069 e590182 7b89069 e590182 7b89069 e590182 d8b753d 7b89069 02fc014 d8b753d 7b89069 e590182 7b89069 e590182 7b89069 e590182 2f1a468 7b89069 2f1a468 7b89069 2f1a468 7b89069 2f1a468 17dfda2 52589e7 17dfda2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 |
##################################### Imports ######################################
# Generic imports
import spaces
import gradio as gr
import json
# Specialized imports
#from utilities.modeling import modeling
# server import
from server import submit_weights #, train_model, submit_weights
# Module imports
from utilities.setup import get_files
from utilities.templates import prompt_template
########################### Global objects and functions ###########################
# Application configuration, loaded once at import time through the project's
# get_files.json_cfg(). main() reads conf['model'] for the model choices,
# the 'general' defaults (repository, output name, training parameters) and
# the 'peft' / 'sft' parameter dicts shown in the "Advanced Tuning" accordion.
conf = get_files.json_cfg()
class update_visibility:
    """Visibility callbacks for the dataset-source Radio in the "Upload Data" tab.

    Each callback takes the Radio's selected value and returns a component
    instance whose only explicit property is ``visible``. Gradio (4.x
    convention) applies that property to the event's output component.
    The methods are static so they can be passed directly as
    ``update_visibility.<name>`` event handlers.
    """

    # Radio choices these callbacks react to; they must match the choices
    # offered by the dataset-source Radio built in main().
    HUB_CHOICE = "Hugging Face Hub Dataset"
    UPLOAD_CHOICE = "Upload Your Own"

    @staticmethod
    def textbox_vis(radio):
        """Show the Hub dataset-name textbox only when the Hub source is selected.

        Returns a ``gr.Textbox`` because the component this callback targets
        (``dataset_predefined`` in main) is a Textbox.
        """
        return gr.Textbox(visible=radio == update_visibility.HUB_CHOICE)

    @staticmethod
    def textbox_button_vis(radio):
        """Show the Hub-dataset load button only when the Hub source is selected."""
        return gr.Button(visible=radio == update_visibility.HUB_CHOICE)

    @staticmethod
    def upload_vis(radio):
        """Show the file upload button only when "Upload Your Own" is selected."""
        return gr.UploadButton(visible=radio == update_visibility.UPLOAD_CHOICE)
@spaces.GPU
def train(model_name,
          inject_prompt,
          dataset_predefined,
          peft,
          sft,
          max_seq_length,
          random_seed,
          num_epochs,
          max_steps,
          data_field,
          repository,
          model_out_name):
    """Run fine-tuning for the selected model (currently a placeholder).

    The arguments arrive, in order, from the UI components wired to the
    "Start Fine Tuning" button. Right now only ``model_name`` and
    ``inject_prompt`` are used. They are echoed back in a status string that
    is displayed in the "Output" textbox. The remaining arguments are
    accepted so the UI wiring is already in place for the real training call.

    Returns:
        str: A status message naming the model and the prompt template.
    """
    # TODO: build and run the trainer from the arguments above, e.g.
    #   trainer = modeling(model_name, max_seq_length, random_seed,
    #                      peft, sft, dataset, data_field)
    #   trainer_stats = trainer.train()
    status = f"Hello!! Using model: {model_name} with template: {inject_prompt}"
    return status
##################################### App UI #######################################
def main():
    """Build the fine-tuning Gradio UI and launch it.

    The UI has four tabs:
    - "About": renders README.md.
    - "Basic Setup": model choice, output repository/name and HF token.
    - "Upload Data": either a Hub dataset name or a file upload, plus the
      prompt template.
    - "Train Model": training parameters, JSON PEFT/SFT settings and the
      "Start Fine Tuning" button wired to ``train``.

    Default values are taken from the module-level ``conf``.
    """
    general_cfg = conf['model']['general']

    with gr.Blocks() as demo:
        with gr.Tabs():
            with gr.TabItem("About"):
                # About page rendered from the repository README.
                gr.Markdown(get_files.load_markdown_file("README.md"))

            with gr.TabItem("Basic Setup"):
                gr.Markdown("# Select Model and Input details")
                # Model selection; defaults to the first configured choice.
                modelnames = conf['model']['choices']
                model_name = gr.Dropdown(label="Supported Models",
                                         choices=modelnames,
                                         value=modelnames[0])
                # Output location of the fine-tuned model, pre-filled from config.
                repository = gr.Textbox(label="Your User Name",
                                        value=general_cfg["repository"])
                model_out_name = gr.Textbox(label="Your Model Output Name",
                                            value=general_cfg["model_name"])
                # NOTE(review): hf_token is not yet passed to any event handler.
                hf_token = gr.Textbox(label="Your Huggingface Token",
                                      type='password',
                                      value='')

            with gr.TabItem("Upload Data"):
                gr.Markdown("# Dataset Selection and Upload")
                # Dataset source toggle: Hub dataset name vs. local file upload.
                dataset_choice = gr.Radio(label="Choose Dataset",
                                          choices=["Hugging Face Hub Dataset", "Upload Your Own"],
                                          value="Hugging Face Hub Dataset")
                dataset_predefined = gr.Textbox(label="Hugging Face Hub Training Dataset",
                                                value='yahma/alpaca-cleaned',
                                                visible=True)
                # Loads the Hub dataset named in the textbox above. The label
                # must not advertise file types: this button does not upload files.
                dataset_predefined_load = gr.Button("Load Hub Dataset")
                dataset_uploaded_load = gr.UploadButton(label="Upload Dataset (.csv, .jsonl, or .txt)",
                                                        file_types=[".csv", ".jsonl", ".txt"],
                                                        visible=False)
                # Shows a preview / status of the loaded dataset.
                data_snippet = gr.Markdown()

                # Show only the widgets that belong to the selected source.
                dataset_choice.change(update_visibility.textbox_vis,
                                      dataset_choice,
                                      dataset_predefined)
                dataset_choice.change(update_visibility.upload_vis,
                                      dataset_choice,
                                      dataset_uploaded_load)
                dataset_choice.change(update_visibility.textbox_button_vis,
                                      dataset_choice,
                                      dataset_predefined_load)

                # Prompt template passed to train().
                inject_prompt = gr.Textbox(label="Prompt Template",
                                           value=prompt_template())

                # Dataset loading events.
                dataset_predefined_load.click(fn=get_files.predefined_dataset,
                                              inputs=dataset_predefined,
                                              outputs=data_snippet)
                # `.upload` fires after the user has picked a file. `.click`
                # fires before the file dialog opens, so the handler would
                # receive the previous (initially empty) value.
                dataset_uploaded_load.upload(fn=get_files.uploaded_dataset,
                                             inputs=dataset_uploaded_load,
                                             outputs=data_snippet)

            with gr.TabItem("Train Model"):
                gr.Markdown("# Model Parameter Selection")
                # Training parameters, pre-filled from the 'general' config section.
                data_field = gr.Textbox(label="Dataset Training Field Name",
                                        value=general_cfg["dataset_text_field"])
                max_seq_length = gr.Textbox(label="Maximum sequence length",
                                            value=general_cfg["max_seq_length"])
                random_seed = gr.Textbox(label="Seed",
                                         value=general_cfg["seed"])
                num_epochs = gr.Textbox(label="Training Epochs",
                                        value=general_cfg["num_train_epochs"])
                max_steps = gr.Textbox(label="Maximum steps",
                                       value=general_cfg["max_steps"])

                # Hyperparameters are editable as JSON but tucked into an accordion.
                with gr.Accordion("Advanced Tuning", open=False):
                    peft = gr.Textbox(label="PEFT Parameters (json)",
                                      value=json.dumps(dict(conf['model']['peft']), indent=4))
                    sft = gr.Textbox(label="SFT Parameters (json)",
                                     value=json.dumps(dict(conf['model']['sft']), indent=4))

                # Execution controls and output.
                tune_btn = gr.Button("Start Fine Tuning")
                gr.Markdown("### Model Progress")
                output = gr.Textbox(label="Output")

                # NOTE(review): only the Hub dataset name is passed to train();
                # an uploaded dataset is not forwarded yet.
                tune_btn.click(fn=train,
                               inputs=[model_name,
                                       inject_prompt,
                                       dataset_predefined,
                                       peft,
                                       sft,
                                       max_seq_length,
                                       random_seed,
                                       num_epochs,
                                       max_steps,
                                       data_field,
                                       repository,
                                       model_out_name,
                                       ],
                               outputs=output)

    demo.launch()
##################################### Launch #######################################
# Launch the UI only when run as a script, not when imported.
if __name__ == "__main__":
    main()