##################################### Imports ######################################
# Generic imports
import gradio as gr
import json

# Specialized imports
#from utilities.modeling import modeling
from datasets import load_dataset

# Module imports
from utilities.setup import get_json_cfg
from utilities.templates import prompt_template

########################### Global objects and functions ###########################

conf = get_json_cfg()
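
# Assumed shape of the config returned by get_json_cfg() (defined in utilities.setup,
# not shown here), inferred from how `conf` is indexed below -- a sketch, not a schema:
#
# {
#   "model": {
#     "choices": ["<model name>", ...],
#     "general": {
#       "dataset_text_field": ..., "max_seq_length": ..., "seed": ...,
#       "num_train_epochs": ..., "max_steps": ..., "repository": ..., "model_name": ...
#     },
#     "peft": { ... },   # shown in the PEFT textbox as JSON
#     "sft":  { ... }    # shown in the SFT textbox as JSON
#   }
# }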

class update_visibility:
    """Toggle component visibility based on the dataset-source radio selection."""

    @staticmethod
    def textbox_vis(radio):
        # The wired output (dataset_predefined) is a Textbox, so return a Textbox update.
        if radio == "Hugging Face Hub Dataset":
            return gr.Textbox(visible=True)
        return gr.Textbox(visible=False)

    @staticmethod
    def textbox_button_vis(radio):
        if radio == "Hugging Face Hub Dataset":
            return gr.Button(visible=True)
        return gr.Button(visible=False)

    @staticmethod
    def upload_vis(radio):
        if radio == "Upload Your Own":
            return gr.UploadButton(visible=True)  # make it visible
        return gr.UploadButton(visible=False)
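
# A more compact alternative to the three near-identical helpers above (a sketch,
# not wired into the UI below): gr.update() returns a generic component update, so
# one function could serve every .change() event. The name `visible_when` is ours.
def visible_when(radio, trigger):
    """Return a visibility update that shows the component only for `trigger`."""
    return gr.update(visible=(radio == trigger))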

class get_datasets:
    """Load training data either from the Hugging Face Hub or from an uploaded file."""

    @staticmethod
    def predefined_dataset(dataset_name):
        global dataset  # bad practice, I know... But just bear with me. Will later update to state dict.
        dataset = load_dataset(dataset_name, split="train")
        return 'Successfully loaded dataset'

    @staticmethod
    def uploaded_dataset(file):
        global dataset  # bad practice, I know... But just bear with me. Will later update to state dict.
        dataset = []
        if file is None:
            return "File not found. Please upload the file again."
        try:
            with open(file, 'r') as f:  # avoid shadowing the `file` argument
                for line in f:
                    dataset.append(json.loads(line.strip()))
            return "File retrieved."
        except FileNotFoundError:
            return "File not found. Please upload the file again."

def about_page():
    return "## About\n\nThis is an application for uploading datasets. You can upload files in .csv, .jsonl, or .txt format. The app will process the file and provide feedback."

def show_about():
    return about_page()

def show_upload():
    return "## Upload\n\nUse the button below to upload your dataset."

def train(model_name,
          inject_prompt,
          dataset_predefined,
          peft,
          sft,
          max_seq_length,
          random_seed,
          num_epochs,
          max_steps,
          data_field,
          repository,
          model_out_name):
    """The model call"""
    # Get models
    # trainer = modeling(model_name, max_seq_length, random_seed,
    #                    peft, sft, dataset, data_field)
    # trainer_stats = trainer.train()
    # Return outputs of training.
    return f"Hello!! Using model: {model_name} with template: {inject_prompt}"

def submit_weights(model, tokenizer, repository, model_out_name, token):
    """Submits the trained model and tokenizer to a Hugging Face Hub repository."""
    repo = repository + '/' + model_out_name
    model.push_to_hub(repo, token=token)
    tokenizer.push_to_hub(repo, token=token)  # tokenizer is now an explicit argument (was an undefined global)
    return 0
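
# Usage sketch (assumption -- not wired in yet; see the "# submit button" placeholder
# in main()). Once train() returns a real model/tokenizer pair, a button could push
# the weights like so, where submit_btn, model_state, tokenizer_state and token_box
# are hypothetical components not defined in this file:
#
#   submit_btn.click(fn=submit_weights,
#                    inputs=[model_state, tokenizer_state, repository,
#                            model_out_name, token_box],
#                    outputs=output)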

##################################### App UI #######################################

def main():
    with gr.Blocks() as demo:
        with gr.Tabs():
            with gr.TabItem("About"):
                gr.Markdown(about_page())
            with gr.TabItem("Train Model"):
                ##### Title Block #####
                gr.Markdown("# SLM Instruction Tuning with Unsloth")

                ##### Initial Model Inputs #####
                gr.Markdown("### Model Inputs")
                # Select model
                modelnames = conf['model']['choices']
                model_name = gr.Dropdown(label="Supported Models",
                                         choices=modelnames,
                                         value=modelnames[0])
                # Prompt template
                inject_prompt = gr.Textbox(label="Prompt Template",
                                           value=prompt_template())
                # Dataset choice
                dataset_choice = gr.Radio(label="Choose Dataset",
                                          choices=["Hugging Face Hub Dataset", "Upload Your Own"],
                                          value="Hugging Face Hub Dataset")
                dataset_predefined = gr.Textbox(label="Hugging Face Hub Training Dataset",
                                                value='yahma/alpaca-cleaned',
                                                visible=True)
                dataset_predefined_load = gr.Button("Load Hub Dataset")
                dataset_uploaded_load = gr.UploadButton(label="Upload Dataset (.csv, .jsonl, or .txt)",
                                                        file_types=[".csv", ".jsonl", ".txt"],
                                                        visible=False)
                data_snippet = gr.Markdown()

                # Toggle visibility of the dataset inputs when the radio selection changes
                dataset_choice.change(update_visibility.textbox_vis,
                                      dataset_choice,
                                      dataset_predefined)
                dataset_choice.change(update_visibility.upload_vis,
                                      dataset_choice,
                                      dataset_uploaded_load)
                dataset_choice.change(update_visibility.textbox_button_vis,
                                      dataset_choice,
                                      dataset_predefined_load)

                # Dataset buttons
                dataset_predefined_load.click(fn=get_datasets.predefined_dataset,
                                              inputs=dataset_predefined,
                                              outputs=data_snippet)
                dataset_uploaded_load.click(fn=get_datasets.uploaded_dataset,
                                            inputs=dataset_uploaded_load,
                                            outputs=data_snippet)

                ##### Model Parameter Inputs #####
                gr.Markdown("### Model Parameter Selection")
                # Parameters
                data_field = gr.Textbox(label="Dataset Training Field Name",
                                        value=conf['model']['general']["dataset_text_field"])
                max_seq_length = gr.Textbox(label="Maximum Sequence Length",
                                            value=conf['model']['general']["max_seq_length"])
                random_seed = gr.Textbox(label="Seed",
                                         value=conf['model']['general']["seed"])
                num_epochs = gr.Textbox(label="Training Epochs",
                                        value=conf['model']['general']["num_train_epochs"])
                max_steps = gr.Textbox(label="Maximum Steps",
                                       value=conf['model']['general']["max_steps"])
                repository = gr.Textbox(label="Repository Name",
                                        value=conf['model']['general']["repository"])
                model_out_name = gr.Textbox(label="Model Output Name",
                                            value=conf['model']['general']["model_name"])

                # Hyperparameters (allow selection, but hide in accordion)
                with gr.Accordion("Advanced Tuning", open=False):
                    sftparams = conf['model']['general']
                    # Accordion container content
                    dict_string = json.dumps(dict(conf['model']['peft']), indent=4)
                    peft = gr.Textbox(label="PEFT Parameters (json)", value=dict_string)
                    dict_string = json.dumps(dict(conf['model']['sft']), indent=4)
                    sft = gr.Textbox(label="SFT Parameters (json)", value=dict_string)

                ##### Execution #####
                # Setup buttons
                tune_btn = gr.Button("Start Fine Tuning")
                gr.Markdown("### Model Progress")
                # Text output (for now)
                output = gr.Textbox(label="Output")

                # Data retrieval
                # Execute buttons
                tune_btn.click(fn=train,
                               inputs=[model_name,
                                       inject_prompt,
                                       dataset_predefined,
                                       peft,
                                       sft,
                                       max_seq_length,
                                       random_seed,
                                       num_epochs,
                                       max_steps,
                                       data_field,
                                       repository,
                                       model_out_name],
                               outputs=output)
                # stop button
                # submit button
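
    # Note (assumption, not something the current code does): fine-tuning is a
    # long-running call, so enabling Gradio's request queue before launching --
    # e.g. demo.queue() -- is usually advisable so the click handler is not cut
    # off by HTTP timeouts.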
    # Launch baby
    demo.launch()

##################################### Launch #######################################

if __name__ == "__main__":
    main()