##################################### Imports ######################################
# Generic imports
import json

import gradio as gr

# Server imports
# TODO: re-enable once training is wired up:
# from utilities.modeling import modeling
from server import get_files, submit_weights  # , train_model

# Module imports
from utilities.setup import get_json_cfg
from utilities.templates import prompt_template

########################### Global objects and functions ###########################
# Application configuration, loaded once at import time.
conf = get_json_cfg()

# Choices of the "Choose Dataset" radio. Shared by the UI definition and the
# visibility callbacks so the string comparisons cannot drift apart.
HUB_DATASET_CHOICE = "Hugging Face Hub Dataset"
UPLOAD_DATASET_CHOICE = "Upload Your Own"


class update_visibility:
    """Gradio callbacks that show/hide dataset widgets based on the source radio.

    Each callback receives the currently selected radio value and returns a
    component update that only sets the ``visible`` property.
    """

    @staticmethod
    def textbox_vis(radio):
        """Show the Hub dataset-name textbox only when the Hub source is selected."""
        # The driven component is a Textbox, so return a Textbox update
        # (previously a Dropdown was returned for this Textbox output).
        return gr.Textbox(visible=(radio == HUB_DATASET_CHOICE))

    @staticmethod
    def textbox_button_vis(radio):
        """Show the Hub dataset load button only when the Hub source is selected."""
        return gr.Button(visible=(radio == HUB_DATASET_CHOICE))

    @staticmethod
    def upload_vis(radio):
        """Show the file-upload button only when the upload source is selected."""
        return gr.UploadButton(visible=(radio == UPLOAD_DATASET_CHOICE))


def train(model_name, inject_prompt, dataset_predefined, peft, sft,
          max_seq_length, random_seed, num_epochs, max_steps, data_field,
          repository, model_out_name):
    """Run fine-tuning for the values collected from the UI (placeholder).

    Training is not wired up yet: every argument except ``model_name`` and
    ``inject_prompt`` is currently ignored, and the function only echoes those
    two values back so the UI flow can be exercised end to end.

    Args:
        model_name: Model selected in the "Supported Models" dropdown.
        inject_prompt: Prompt template text.
        dataset_predefined: Hugging Face Hub dataset name from the textbox.
        peft: PEFT parameters, as the JSON text from the "Advanced Tuning" box.
        sft: SFT parameters, as the JSON text from the "Advanced Tuning" box.
        max_seq_length, random_seed, num_epochs, max_steps: Raw textbox
            values (presumably strings, as Gradio Textboxes deliver them —
            convert before use).
        data_field: Name of the dataset field to train on.
        repository: User name the output model is published under.
        model_out_name: Name of the output model.

    Returns:
        A status message displayed in the "Output" textbox.
    """
    # TODO: replace with the real training call, e.g.
    #   trainer = modeling(model_name, max_seq_length, random_seed,
    #                      peft, sft, dataset, data_field)
    #   trainer_stats = trainer.train()
    return f"Hello!! Using model: {model_name} with template: {inject_prompt}"


##################################### App UI #######################################
def main():
    """Build the fine-tuning Gradio UI and launch it."""
    general = conf['model']['general']

    with gr.Blocks() as demo:
        with gr.Tabs():
            with gr.TabItem("About"):
                # About page: render the project README.
                gr.Markdown(get_files.load_markdown_file("README.md"))

            with gr.TabItem("Basic Setup"):
                gr.Markdown("# Select Model and Input details")
                # Select model
                modelnames = conf['model']['choices']
                model_name = gr.Dropdown(label="Supported Models",
                                         choices=modelnames,
                                         value=modelnames[0])
                # Generic output/account parameters
                repository = gr.Textbox(label="Your User Name",
                                        value=general["repository"])
                model_out_name = gr.Textbox(label="Your Model Output Name",
                                            value=general["model_name"])
                # NOTE(review): the token is collected but not yet passed to train().
                hf_token = gr.Textbox(label="Your Huggingface Token",
                                      type='password', value='')

            with gr.TabItem("Upload Data"):
                gr.Markdown("# Dataset Selection and Upload")
                # Toggle between dataset source types.
                dataset_choice = gr.Radio(
                    label="Choose Dataset",
                    choices=[HUB_DATASET_CHOICE, UPLOAD_DATASET_CHOICE],
                    value=HUB_DATASET_CHOICE,
                )
                dataset_predefined = gr.Textbox(
                    label="Hugging Face Hub Training Dataset",
                    value='yahma/alpaca-cleaned',
                    visible=True,
                )
                # NOTE(review): this button loads a Hub dataset but reuses the
                # upload button's label; consider a label such as "Load Dataset".
                dataset_predefined_load = gr.Button(
                    "Upload Dataset (.csv, .jsonl, or .txt)")
                dataset_uploaded_load = gr.UploadButton(
                    label="Upload Dataset (.csv, .jsonl, or .txt)",
                    file_types=[".csv", ".jsonl", ".txt"],
                    visible=False,
                )
                # Output area the dataset loaders write to, to show whether
                # loading succeeded.
                data_snippet = gr.Markdown()

                # Show only the widgets that belong to the selected source.
                dataset_choice.change(update_visibility.textbox_vis,
                                      dataset_choice, dataset_predefined)
                dataset_choice.change(update_visibility.upload_vis,
                                      dataset_choice, dataset_uploaded_load)
                dataset_choice.change(update_visibility.textbox_button_vis,
                                      dataset_choice, dataset_predefined_load)

                # Prompt template
                inject_prompt = gr.Textbox(label="Prompt Template",
                                           value=prompt_template())

                # Dataset loaders
                dataset_predefined_load.click(fn=get_files.predefined_dataset,
                                              inputs=dataset_predefined,
                                              outputs=data_snippet)
                # `.upload` fires once the user has picked a file. `.click` fires
                # on the button press, before any file exists, so the handler
                # would only ever see the previous (initially empty) value.
                dataset_uploaded_load.upload(fn=get_files.uploaded_dataset,
                                             inputs=dataset_uploaded_load,
                                             outputs=data_snippet)

            with gr.TabItem("Train Model"):
                ##### Model Parameter Inputs #####
                gr.Markdown("# Model Parameter Selection")
                data_field = gr.Textbox(label="Dataset Training Field Name",
                                        value=general["dataset_text_field"])
                max_seq_length = gr.Textbox(label="Maximum sequence length",
                                            value=general["max_seq_length"])
                random_seed = gr.Textbox(label="Seed", value=general["seed"])
                num_epochs = gr.Textbox(label="Training Epochs",
                                        value=general["num_train_epochs"])
                max_steps = gr.Textbox(label="Maximum steps",
                                       value=general["max_steps"])

                # Hyperparameters: editable as JSON, but tucked in an accordion.
                with gr.Accordion("Advanced Tuning", open=False):
                    peft = gr.Textbox(
                        label="PEFT Parameters (json)",
                        value=json.dumps(dict(conf['model']['peft']), indent=4),
                    )
                    sft = gr.Textbox(
                        label="SFT Parameters (json)",
                        value=json.dumps(dict(conf['model']['sft']), indent=4),
                    )

                ##### Execution #####
                tune_btn = gr.Button("Start Fine Tuning")
                gr.Markdown("### Model Progress")
                # Plain-text output for now.
                output = gr.Textbox(label="Output")

                # The input order must match train()'s positional parameters.
                tune_btn.click(fn=train,
                               inputs=[model_name, inject_prompt,
                                       dataset_predefined, peft, sft,
                                       max_seq_length, random_seed,
                                       num_epochs, max_steps, data_field,
                                       repository, model_out_name],
                               outputs=output)
                # TODO: stop button
                # TODO: submit button

    demo.launch()


##################################### Launch #######################################
if __name__ == "__main__":
    main()