# app.py -- copied from a Hugging Face Space file view
# (author: tykiww; commit "Update app.py", cfbd01a, verified; 9.79 kB).
##################################### Imports ######################################
# Generic imports
import gradio as gr
import json
# Specialized imports
#from utilities.modeling import modeling
from datasets import load_dataset
# Module imports
from utilities.setup import get_json_cfg
from utilities.templates import prompt_template
########################### Global objects and functions ###########################
# Application configuration, loaded once at import time. main() reads
# conf['model']['choices'], conf['model']['general'], conf['model']['peft']
# and conf['model']['sft'] to populate the default values of the UI widgets.
conf = get_json_cfg()
class update_visibility:
    """Callbacks that toggle dataset widgets from the "Choose Dataset" radio.

    Each method receives the currently selected radio label and returns a
    component that carries only a ``visible`` flag; Gradio is expected to
    apply it as a property update to whichever output component the event
    is wired to. The methods are static so they can be passed around as
    plain functions (``update_visibility.textbox_vis``) or called on an
    instance.
    """

    @staticmethod
    def textbox_vis(radio):
        """Show the Hub dataset-name textbox only for the Hub option.

        Args:
            radio: Selected label of the dataset-choice radio.

        Returns:
            A ``gr.Textbox`` update with ``visible`` set accordingly.
        """
        # The wired output is a Textbox, so return the matching component type.
        return gr.Textbox(visible=(radio == "Hugging Face Hub Dataset"))

    @staticmethod
    def textbox_button_vis(radio):
        """Show the Hub "load" button only for the Hub option.

        Args:
            radio: Selected label of the dataset-choice radio.

        Returns:
            A ``gr.Button`` update with ``visible`` set accordingly.
        """
        return gr.Button(visible=(radio == "Hugging Face Hub Dataset"))

    @staticmethod
    def upload_vis(radio):
        """Show the file upload button only for the "Upload Your Own" option.

        Args:
            radio: Selected label of the dataset-choice radio.

        Returns:
            A ``gr.UploadButton`` update with ``visible`` set accordingly.
        """
        return gr.UploadButton(visible=(radio == "Upload Your Own"))
class get_datasets:
    """Loaders that populate the module-level ``dataset`` global.

    Both loaders return a short status string intended for display in the UI.
    The methods are static so they can be used directly as event callbacks
    (``get_datasets.predefined_dataset``).
    """

    @staticmethod
    def predefined_dataset(dataset_name):
        """Load the ``train`` split of a Hugging Face Hub dataset.

        Args:
            dataset_name: Dataset identifier passed to ``load_dataset``.

        Returns:
            The status string ``'Successfully loaded dataset'``.
        """
        # TODO: replace the module-level global with per-session state.
        global dataset
        dataset = load_dataset(dataset_name, split="train")
        return 'Successfully loaded dataset'

    @staticmethod
    def uploaded_dataset(file):
        """Load a JSON Lines file into the module-level ``dataset`` list.

        Every non-blank line is parsed with ``json.loads``. The global is
        only replaced after the whole file parsed successfully, so a failed
        upload never wipes or half-populates a previously loaded dataset.

        Args:
            file: Path of the uploaded file, or ``None`` when nothing was
                uploaded.

        Returns:
            ``"File retrieved."`` on success, otherwise a message describing
            why the file could not be loaded.
        """
        # TODO: replace the module-level global with per-session state.
        global dataset
        if file is None:
            return "File not found. Please upload the file again."

        records = []
        try:
            with open(file, "r", encoding="utf-8") as handle:
                for raw_line in handle:
                    line = raw_line.strip()
                    # Blank lines (e.g. a trailing newline) are not records.
                    if line:
                        records.append(json.loads(line))
        except FileNotFoundError:
            return "File not found. Please upload the file again."
        except (json.JSONDecodeError, UnicodeDecodeError):
            # Only JSON Lines content can be parsed here; report instead of
            # crashing the callback on .csv/.txt or malformed input.
            return "Could not parse file. Each line must be valid JSON (.jsonl format)."

        dataset = records
        return "File retrieved."
def about_page():
    """Return the Markdown text of the "About" section."""
    about_markdown = (
        "## About\n\n"
        "This is an application for uploading datasets. "
        "You can upload files in .csv, .jsonl, or .txt format. "
        "The app will process the file and provide feedback."
    )
    return about_markdown
def show_about():
    """Return the Markdown heading and description for the "About" view."""
    heading = "## About"
    description = (
        "This is an application for uploading datasets. "
        "You can upload files in .csv, .jsonl, or .txt format. "
        "The app will process the file and provide feedback."
    )
    # Heading and body are separated by one blank line.
    return heading + "\n\n" + description
def show_upload():
    """Return the Markdown heading and hint for the "Upload" view."""
    heading = "## Upload"
    hint = "Use the button below to upload your dataset."
    # Heading and body are separated by one blank line.
    return heading + "\n\n" + hint
def train(model_name,
          inject_prompt,
          dataset_predefined,
          peft,
          sft,
          max_seq_length,
          random_seed,
          num_epochs,
          max_steps,
          data_field,
          repository,
          model_out_name):
    """Placeholder for the fine-tuning call.

    The actual training step is currently disabled, so every argument other
    than ``model_name`` and ``inject_prompt`` is accepted but unused.

    Returns:
        A greeting string naming the selected model and prompt template.
    """
    # Disabled training call, kept for reference until modeling is re-enabled:
    #   trainer = modeling(model_name, max_seq_length, random_seed,
    #                      peft, sft, dataset, data_field)
    #   trainer_stats = trainer.train()
    status = f"Hello!! Using model: {model_name} with template: {inject_prompt}"
    return status
def submit_weights(model, repository, model_out_name, token, tokenizer=None):
    """Push a model, and optionally its tokenizer, to a hub repository.

    Args:
        model: Object exposing ``push_to_hub(repo, token=...)``.
        repository: Namespace part of the target repo id.
        model_out_name: Name part of the target repo id.
        token: Access token forwarded to ``push_to_hub``.
        tokenizer: Optional object exposing ``push_to_hub(repo, token=...)``.
            When omitted, a module-level ``tokenizer`` is used if one has
            been defined; otherwise only the model is pushed.

    Returns:
        ``0`` once the pushes have been issued.
    """
    repo = f"{repository}/{model_out_name}"
    model.push_to_hub(repo, token=token)

    # No tokenizer is defined at module level in this file, so an
    # unconditional reference would raise NameError after the model upload.
    if tokenizer is None:
        tokenizer = globals().get("tokenizer")
    if tokenizer is not None:
        tokenizer.push_to_hub(repo, token=token)
    return 0
##################################### App UI #######################################
def main():
    """Build the Gradio Blocks interface and launch it.

    The UI has two tabs: "About", which shows the static description, and
    "Train Model", which collects the model choice, prompt template,
    dataset source and training parameters, then passes them to ``train``
    when "Start Fine Tuning" is pressed. Default widget values come from
    the module-level ``conf``.
    """
    general = conf['model']['general']

    with gr.Blocks() as demo:
        with gr.Tabs():
            with gr.TabItem("About"):
                gr.Markdown(show_about())

            with gr.TabItem("Train Model"):
                ##### Title Block #####
                gr.Markdown("# SLM Instruction Tuning with Unsloth")

                ##### Initial Model Inputs #####
                gr.Markdown("### Model Inputs")

                # Select Model
                modelnames = conf['model']['choices']
                model_name = gr.Dropdown(
                    label="Supported Models",
                    choices=modelnames,
                    # Guard against an empty choice list in the config.
                    value=modelnames[0] if modelnames else None,
                )

                # Prompt template
                inject_prompt = gr.Textbox(label="Prompt Template",
                                           value=prompt_template())

                # Dataset choice; these labels must match the strings
                # compared in the update_visibility callbacks.
                dataset_choice = gr.Radio(
                    label="Choose Dataset",
                    choices=["Hugging Face Hub Dataset", "Upload Your Own"],
                    value="Hugging Face Hub Dataset",
                )
                dataset_predefined = gr.Textbox(
                    label="Hugging Face Hub Training Dataset",
                    value='yahma/alpaca-cleaned',
                    visible=True,
                )
                # This button loads the named Hub dataset; it does not upload.
                dataset_predefined_load = gr.Button("Load Dataset")
                dataset_uploaded_load = gr.UploadButton(
                    label="Upload Dataset (.csv, .jsonl, or .txt)",
                    file_types=[".csv", ".jsonl", ".txt"],
                    visible=False,
                )
                data_snippet = gr.Markdown()

                # Switch between the Hub widgets and the upload widget.
                dataset_choice.change(update_visibility.textbox_vis,
                                      dataset_choice,
                                      dataset_predefined)
                dataset_choice.change(update_visibility.upload_vis,
                                      dataset_choice,
                                      dataset_uploaded_load)
                dataset_choice.change(update_visibility.textbox_button_vis,
                                      dataset_choice,
                                      dataset_predefined_load)

                # Dataset buttons
                dataset_predefined_load.click(
                    fn=get_datasets.predefined_dataset,
                    inputs=dataset_predefined,
                    outputs=data_snippet,
                )
                # `upload` fires once a file has been received; `click`
                # fires when the button is pressed, before any file exists,
                # so the callback would be handed None.
                dataset_uploaded_load.upload(
                    fn=get_datasets.uploaded_dataset,
                    inputs=dataset_uploaded_load,
                    outputs=data_snippet,
                )

                ##### Model Parameter Inputs #####
                gr.Markdown("### Model Parameter Selection")

                # Parameters
                data_field = gr.Textbox(label="Dataset Training Field Name",
                                        value=general["dataset_text_field"])
                max_seq_length = gr.Textbox(label="Maximum sequence length",
                                            value=general["max_seq_length"])
                random_seed = gr.Textbox(label="Seed",
                                         value=general["seed"])
                num_epochs = gr.Textbox(label="Training Epochs",
                                        value=general["num_train_epochs"])
                max_steps = gr.Textbox(label="Maximum steps",
                                       value=general["max_steps"])
                repository = gr.Textbox(label="Repository Name",
                                        value=general["repository"])
                model_out_name = gr.Textbox(label="Model Output Name",
                                            value=general["model_name"])

                # Hyperparameters (allow selection, but hide in accordion.)
                with gr.Accordion("Advanced Tuning", open=False):
                    # Shown as editable JSON text.
                    peft_json = json.dumps(dict(conf['model']['peft']), indent=4)
                    peft = gr.Textbox(label="PEFT Parameters (json)",
                                      value=peft_json)
                    sft_json = json.dumps(dict(conf['model']['sft']), indent=4)
                    sft = gr.Textbox(label="SFT Parameters (json)",
                                     value=sft_json)

                ##### Execution #####
                # Setup buttons
                tune_btn = gr.Button("Start Fine Tuning")
                gr.Markdown("### Model Progress")
                # Text output (for now)
                output = gr.Textbox(label="Output")

                # Execute buttons; input order must match train()'s
                # positional parameters.
                tune_btn.click(fn=train,
                               inputs=[model_name,
                                       inject_prompt,
                                       dataset_predefined,
                                       peft,
                                       sft,
                                       max_seq_length,
                                       random_seed,
                                       num_epochs,
                                       max_steps,
                                       data_field,
                                       repository,
                                       model_out_name],
                               outputs=output)
                # stop button
                # submit button

    # Launch the app
    demo.launch()
##################################### Launch #######################################
# Build and launch the UI only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()