# Source: app.py as shown in the Hugging Face Space file viewer (author: tykiww).
# Last commit: "Update app.py" (bfbbbea, verified); file size 8.44 kB.
##################################### Imports ######################################
# Generic imports
import gradio as gr
import json
# Specialized imports
#from utilities.modeling import modeling
# server import
from server import get_files, submit_weights #, train_model, submit_weights
# Module imports
from utilities.setup import get_json_cfg
from utilities.templates import prompt_template
########################### Global objects and functions ###########################
# Application configuration, loaded once at import time. main() reads
# conf['model'] for the model choices, the default field values
# (conf['model']['general']) and the PEFT/SFT parameter blocks.
conf = get_json_cfg()
class update_visibility:
    """Callbacks that show or hide the dataset inputs.

    Every callback receives the current value of the "Choose Dataset" radio
    and returns a component that carries only a ``visible`` flag; the event
    listener that registered the callback applies it to its output component.

    The callbacks are static methods, so they can be referenced either through
    the class (``update_visibility.textbox_vis``) or through an instance.
    """

    @staticmethod
    def textbox_vis(radio):
        """Return a Textbox update, visible only for the Hub dataset option.

        Args:
            radio: Selected label of the dataset-choice radio.

        Returns:
            A ``gr.Textbox`` whose ``visible`` flag is True when ``radio`` is
            "Hugging Face Hub Dataset" and False otherwise.
        """
        # The wired output (dataset_predefined) is a gr.Textbox, so the update
        # is expressed with the matching component type rather than Dropdown.
        return gr.Textbox(visible=(radio == "Hugging Face Hub Dataset"))

    @staticmethod
    def textbox_button_vis(radio):
        """Return a Button update, visible only for the Hub dataset option.

        Args:
            radio: Selected label of the dataset-choice radio.

        Returns:
            A ``gr.Button`` whose ``visible`` flag is True when ``radio`` is
            "Hugging Face Hub Dataset" and False otherwise.
        """
        return gr.Button(visible=(radio == "Hugging Face Hub Dataset"))

    @staticmethod
    def upload_vis(radio):
        """Return an UploadButton update, visible only for the upload option.

        Args:
            radio: Selected label of the dataset-choice radio.

        Returns:
            A ``gr.UploadButton`` whose ``visible`` flag is True when ``radio``
            is "Upload Your Own" and False otherwise.
        """
        return gr.UploadButton(visible=(radio == "Upload Your Own"))
def train(
    model_name,
    inject_prompt,
    dataset_predefined,
    peft,
    sft,
    max_seq_length,
    random_seed,
    num_epochs,
    max_steps,
    data_field,
    repository,
    model_out_name,
):
    """Entry point for the model call; currently a placeholder.

    All arguments are accepted so the signature matches the inputs wired to
    the "Start Fine Tuning" button, but only ``model_name`` and
    ``inject_prompt`` are used for now.

    Returns:
        A greeting string that echoes the chosen model name and prompt
        template.
    """
    # The real training path (building a `modeling` trainer and calling its
    # train()) is still disabled, so no training statistics are produced yet.
    greeting = f"Hello!! Using model: {model_name} with template: {inject_prompt}"
    return greeting
##################################### App UI #######################################
def main():
    """Build the Gradio Blocks UI and launch it.

    The interface has four tabs:

    * "About": renders README.md as markdown.
    * "Basic Setup": model choice, user name, output model name and token.
    * "Upload Data": pick a Hub dataset or upload a file, plus the prompt
      template.
    * "Train Model": training parameters, PEFT/SFT JSON blocks and the button
      that calls ``train``.

    Default values are read from the module-level ``conf`` mapping.
    """
    # Shared defaults used by several textboxes below.
    general_cfg = conf['model']['general']

    with gr.Blocks() as demo:
        with gr.Tabs():
            with gr.TabItem("About"):
                # About page!!
                gr.Markdown(get_files.load_markdown_file("README.md"))

            with gr.TabItem("Basic Setup"):
                gr.Markdown("# Select Model and Input details")
                # Select Model
                modelnames = conf['model']['choices']
                model_name = gr.Dropdown(
                    label="Supported Models",
                    choices=modelnames,
                    # Guard against an empty choices list instead of raising
                    # IndexError while the UI is being built.
                    value=modelnames[0] if modelnames else None,
                )
                # Select Generic Model parameters
                repository = gr.Textbox(label="Your User Name",
                                        value=general_cfg["repository"])
                model_out_name = gr.Textbox(label="Your Model Output Name",
                                            value=general_cfg["model_name"])
                # NOTE(review): the token is collected here but is not part of
                # the inputs passed to train() below.
                hf_token = gr.Textbox(label="Your Huggingface Token",
                                      type='password',
                                      value='')

            with gr.TabItem("Upload Data"):
                # Toggle dataset load types
                gr.Markdown("# Dataset Selection and Upload")
                dataset_choice = gr.Radio(
                    label="Choose Dataset",
                    choices=["Hugging Face Hub Dataset", "Upload Your Own"],
                    value="Hugging Face Hub Dataset",
                )
                dataset_predefined = gr.Textbox(
                    label="Hugging Face Hub Training Dataset",
                    value='yahma/alpaca-cleaned',
                    visible=True,
                )
                # NOTE(review): this button loads the Hub dataset named in the
                # textbox above, yet carries the upload label - confirm wording.
                dataset_predefined_load = gr.Button(
                    "Upload Dataset (.csv, .jsonl, or .txt)")
                dataset_uploaded_load = gr.UploadButton(
                    label="Upload Dataset (.csv, .jsonl, or .txt)",
                    file_types=[".csv", ".jsonl", ".txt"],
                    visible=False,
                )
                # Safety output to show if upload succeeded.
                data_snippet = gr.Markdown()

                # Visibility togglers: one listener per component that has to
                # follow the radio selection.
                dataset_choice.change(update_visibility.textbox_vis,
                                      dataset_choice,
                                      dataset_predefined)
                dataset_choice.change(update_visibility.upload_vis,
                                      dataset_choice,
                                      dataset_uploaded_load)
                dataset_choice.change(update_visibility.textbox_button_vis,
                                      dataset_choice,
                                      dataset_predefined_load)

                # Prompt template
                inject_prompt = gr.Textbox(label="Prompt Template",
                                           value=prompt_template())

                # Dataset buttons
                dataset_predefined_load.click(fn=get_files.predefined_dataset,
                                              inputs=dataset_predefined,
                                              outputs=data_snippet)
                # An UploadButton fires `upload` once a file has been received;
                # `click` fires when the button is pressed, before any file is
                # available, so the handler must be attached to `upload`.
                dataset_uploaded_load.upload(fn=get_files.uploaded_dataset,
                                             inputs=dataset_uploaded_load,
                                             outputs=data_snippet)

            with gr.TabItem("Train Model"):
                ##### Model Parameter Inputs #####
                gr.Markdown("# Model Parameter Selection")
                # Parameters
                data_field = gr.Textbox(label="Dataset Training Field Name",
                                        value=general_cfg["dataset_text_field"])
                max_seq_length = gr.Textbox(label="Maximum sequence length",
                                            value=general_cfg["max_seq_length"])
                random_seed = gr.Textbox(label="Seed",
                                         value=general_cfg["seed"])
                num_epochs = gr.Textbox(label="Training Epochs",
                                        value=general_cfg["num_train_epochs"])
                max_steps = gr.Textbox(label="Maximum steps",
                                       value=general_cfg["max_steps"])

                # Hyperparameters (allow selection, but hide in accordion.)
                with gr.Accordion("Advanced Tuning", open=False):
                    # Show each parameter group as editable, pretty-printed JSON.
                    peft_json = json.dumps(dict(conf['model']['peft']), indent=4)
                    peft = gr.Textbox(label="PEFT Parameters (json)",
                                      value=peft_json)
                    sft_json = json.dumps(dict(conf['model']['sft']), indent=4)
                    sft = gr.Textbox(label="SFT Parameters (json)",
                                     value=sft_json)

                ##### Execution #####
                # Setup buttons
                tune_btn = gr.Button("Start Fine Tuning")
                gr.Markdown("### Model Progress")
                # Text output (for now)
                output = gr.Textbox(label="Output")

                # Execute buttons: the input order must match train()'s
                # positional parameters.
                tune_btn.click(fn=train,
                               inputs=[model_name,
                                       inject_prompt,
                                       dataset_predefined,
                                       peft,
                                       sft,
                                       max_seq_length,
                                       random_seed,
                                       num_epochs,
                                       max_steps,
                                       data_field,
                                       repository,
                                       model_out_name],
                               outputs=output)
                # stop button
                # submit button

    # Launch baby
    demo.launch()
##################################### Launch #######################################
# Build and launch the UI only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()