File size: 3,396 Bytes
ac66ae2 ad99c34 7465957 cbbb9fd df44c11 083fde1 ada7179 251d88f c8534fb e92ef1c df44c11 38576e5 df44c11 467c88a df44c11 38576e5 76019db df44c11 38576e5 df44c11 065dd39 76019db df9a90e 483c87c df44c11 ada7179 df44c11 ada7179 df44c11 ada7179 c8534fb 251d88f df44c11 92146e5 ada7179 251d88f cbbb9fd 251d88f cbbb9fd ada7179 cbbb9fd ada7179 0fb434b 251d88f ada7179 2b03f9f ada7179 2b03f9f e69ea59 df44c11 e92ef1c ada7179 065dd39 a184b8b 2371111 835fa92 13e776a 94ca6da 13e776a 2371111 083fde1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 |
import gradio as gr
import os, torch
from datasets import load_dataset
from huggingface_hub import HfApi, login
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
# Hardware note: fine-tuning is meant to run on an NVidia A10G Large
# (sleep after 1 hour).

# Hub namespace under which the copied/fine-tuned model is published.
profile = "bstraehle"

# Default model and dataset shown in the UI.
base_model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
dataset = "gretelai/synthetic_text_to_sql"

# Choices of the "Action" radio button; process() dispatches on these exact strings.
action_1 = "Fine-tune pre-trained model"
action_2 = "Prompt fine-tuned model"

# Default prompts. ``{schema}`` in the system prompt is a str.format placeholder.
system_prompt = (
    "You are a text to SQL query translator. "
    "Given a question in English, generate a SQL query based on the provided SCHEMA. "
    "Do not generate any additional text. "
    "SCHEMA: {schema}"
)
user_prompt = (
    "What is the total trade value and average price "
    "for each trader and stock in the trade_history table?"
)
schema = (
    "CREATE TABLE trade_history ("
    "id INT, trader_id INT, stock VARCHAR(255), "
    "price DECIMAL(5,2), quantity INT, trade_time TIMESTAMP);"
)
def prompt_model(model_id, system_prompt, user_prompt, schema):
    """Ask the model at *model_id* to translate *user_prompt* into SQL.

    Builds a ``text-generation`` pipeline for ``model_id`` with bfloat16
    weights and ``device_map="auto"``. It then sends a two-message chat: a
    system message (``system_prompt`` with ``{schema}`` filled in) and the
    user question.

    Args:
        model_id: Hub repo id or local path of the model to prompt.
        system_prompt: System message template containing a ``{schema}`` field.
        user_prompt: The natural-language question.
        schema: Schema text substituted into ``system_prompt``.

    Returns:
        The content of the last message in the pipeline's chat output, i.e.
        the model's reply.
    """
    # NOTE(review): the whole pipeline (model load) is rebuilt on every call.
    pipe = pipeline(
        "text-generation",
        model=model_id,
        model_kwargs={"torch_dtype": torch.bfloat16},
        device_map="auto",
    )
    # No empty trailing assistant turn here. The pipeline adds the assistant
    # generation header itself. An explicit empty assistant message is either
    # rendered as a closed, empty assistant turn (older transformers) or flips
    # the pipeline into "continue the final message" mode (newer versions).
    messages = [
        {"role": "system", "content": system_prompt.format(schema=schema)},
        {"role": "user", "content": user_prompt},
    ]
    output = pipe(messages)
    # For chat input the pipeline returns the conversation with the generated
    # assistant reply as its final message.
    return output[0]["generated_text"][-1]["content"]
def fine_tune_model(base_model_id):
    """Publish *base_model_id* under the module's ``profile`` namespace.

    Downloads the model (saving it locally), then uploads it together with its
    tokenizer.

    Returns:
        The repo id the model was uploaded to.

    NOTE(review): despite the name, no training happens here yet. The base
    model is re-uploaded unchanged.
    """
    saved_tokenizer = download_model(base_model_id)
    return upload_model(base_model_id, saved_tokenizer)
def download_model(base_model_id):
    """Load *base_model_id* from the Hub and save its weights locally.

    The model is written with ``save_pretrained`` into a relative local
    directory whose path equals ``base_model_id``
    (e.g. ``meta-llama/Meta-Llama-3-8B-Instruct``).

    Returns:
        The tokenizer. It is not saved to that directory.
    """
    tokenizer = AutoTokenizer.from_pretrained(base_model_id)
    # torch_dtype="auto" keeps the checkpoint's own dtype. Without it,
    # transformers materializes the weights in float32, which doubles RAM use
    # and the size of the re-saved copy for a bf16 checkpoint.
    model = AutoModelForCausalLM.from_pretrained(base_model_id, torch_dtype="auto")
    model.save_pretrained(base_model_id)
    return tokenizer
#def download_dataset(dataset):
# ds = load_dataset(dataset)
# return ""
def upload_model(base_model_id, tokenizer):
    """Upload the locally saved model folder and *tokenizer* to the Hub.

    The target repo id is ``replace_profile(base_model_id)``. The uploaded
    folder is the local directory named ``base_model_id``.

    Authentication uses the ``HF_TOKEN`` environment variable; a missing
    variable raises ``KeyError``.

    Returns:
        The target repo id.
    """
    fine_tuned_model_id = replace_profile(base_model_id)
    login(token=os.environ["HF_TOKEN"])
    api = HfApi()
    # exist_ok: a second run must not fail just because an earlier run already
    # created the repo; upload_folder then updates its contents.
    api.create_repo(repo_id=fine_tuned_model_id, exist_ok=True)
    api.upload_folder(
        folder_path=base_model_id,
        repo_id=fine_tuned_model_id,
    )
    tokenizer.push_to_hub(fine_tuned_model_id)
    return fine_tuned_model_id
def replace_profile(base_model_id):
    """Move *base_model_id* into the module-level ``profile`` namespace.

    Everything up to and including the last ``/`` is dropped and replaced by
    ``profile``. An id containing no ``/`` is kept whole. For example, with
    ``profile == "bstraehle"``, ``"meta-llama/Meta-Llama-3-8B-Instruct"``
    becomes ``"bstraehle/Meta-Llama-3-8B-Instruct"``.
    """
    _, _, repo_name = base_model_id.rpartition("/")
    return f"{profile}/{repo_name}"
def process(action, base_model_id, dataset, system_prompt, user_prompt, schema):
    """Gradio callback that dispatches on the selected *action*.

    - ``action_1``: runs ``fine_tune_model(base_model_id)`` and returns the
      resulting repo id.
    - ``action_2``: prompts the ``profile`` copy of ``base_model_id`` and
      returns the completion.

    ``dataset`` comes from the UI but is currently unused.

    Raises:
        gr.Error: for any other action value. Previously such a value fell
        through both branches and crashed with ``UnboundLocalError``.
    """
    if action == action_1:
        return fine_tune_model(base_model_id)
    if action == action_2:
        fine_tuned_model_id = replace_profile(base_model_id)
        return prompt_model(fine_tuned_model_id, system_prompt, user_prompt, schema)
    raise gr.Error(f"Unknown action: {action!r}")
# UI: input order matches process(action, base_model_id, dataset,
# system_prompt, user_prompt, schema).
demo = gr.Interface(
    fn=process,
    inputs=[
        gr.Radio([action_1, action_2], label="Action", value=action_1),
        gr.Textbox(label="Base Model ID", value=base_model_id, lines=1),
        gr.Textbox(label="Dataset", value=dataset, lines=1),
        gr.Textbox(label="System Prompt", value=system_prompt, lines=2),
        gr.Textbox(label="User Prompt", value=user_prompt, lines=2),
        gr.Textbox(label="Schema", value=schema, lines=2),
    ],
    outputs=[gr.Textbox(label="Completion")],
)

# Start the server only when executed as a script, so importing this module
# merely builds ``demo``.
if __name__ == "__main__":
    demo.launch()