bstraehle commited on
Commit
ada7179
1 Parent(s): e69ea59

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -28
app.py CHANGED
@@ -4,15 +4,7 @@ from datasets import load_dataset
4
  from huggingface_hub import HfApi, login
5
  from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
6
 
7
- # Run on NVidia A10G Large (sleep after 1 hour)
8
-
9
- # Model IDs:
10
- #
11
- # meta-llama/Meta-Llama-3-8B-Instruct
12
-
13
- # Datasets:
14
- #
15
- # gretelai/synthetic_text_to_sql
16
 
17
  profile = "bstraehle"
18
 
@@ -24,7 +16,6 @@ user_prompt = "What is the total trade value and average price for each trader a
24
  schema = "CREATE TABLE trade_history (id INT, trader_id INT, stock VARCHAR(255), price DECIMAL(5,2), quantity INT, trade_time TIMESTAMP);"
25
 
26
  base_model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
27
- fine_tuned_model_id = "bstraehle/Meta-Llama-3-8B-Instruct"
28
  dataset = "gretelai/synthetic_text_to_sql"
29
 
30
  def prompt_model(model_id, system_prompt, user_prompt, schema):
@@ -40,16 +31,16 @@ def prompt_model(model_id, system_prompt, user_prompt, schema):
40
 
41
  return output[0]["generated_text"][-1]["content"]
42
 
43
- def fine_tune_model(model_id):
44
- tokenizer = download_model(model_id)
45
- model_repo_name = upload_model(model_id, tokenizer)
46
 
47
- return model_repo_name
48
 
49
- def download_model(model_id):
50
- tokenizer = AutoTokenizer.from_pretrained(model_id)
51
- model = AutoModelForCausalLM.from_pretrained(model_id)
52
- model.save_pretrained(model_id)
53
 
54
  return tokenizer
55
 
@@ -57,29 +48,31 @@ def download_model(model_id):
57
  # ds = load_dataset(dataset)
58
  # return ""
59
 
60
- def upload_model(model_id, tokenizer):
61
- model_name = model_id[model_id.rfind('/')+1:]
62
- model_repo_name = f"{profile}/{model_name}"
63
 
64
  login(token=os.environ["HF_TOKEN"])
65
 
66
  api = HfApi()
67
- api.create_repo(repo_id=model_repo_name)
68
  api.upload_folder(
69
- folder_path=model_id,
70
- repo_id=model_repo_name
71
  )
72
 
73
- tokenizer.push_to_hub(model_repo_name)
 
 
74
 
75
- return model_repo_name
 
 
76
 
77
  def process(action, base_model_id, dataset, system_prompt, user_prompt, schema):
78
  if action == action_1:
79
  result = fine_tune_model(base_model_id)
80
  elif action == action_2:
81
- model_id = base_model_id[base_model_id.rfind('/')+1:]
82
- fine_tuned_model_id = f"{profile}/{model_id}"
83
  result = prompt_model(fine_tuned_model_id, system_prompt, user_prompt, schema)
84
 
85
  return result
 
4
  from huggingface_hub import HfApi, login
5
  from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
6
 
7
+ # Fine-tune on NVidia A10G Large (sleep after 1 hour)
 
 
 
 
 
 
 
 
8
 
9
  profile = "bstraehle"
10
 
 
16
  schema = "CREATE TABLE trade_history (id INT, trader_id INT, stock VARCHAR(255), price DECIMAL(5,2), quantity INT, trade_time TIMESTAMP);"
17
 
18
  base_model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
 
19
  dataset = "gretelai/synthetic_text_to_sql"
20
 
21
  def prompt_model(model_id, system_prompt, user_prompt, schema):
 
31
 
32
  return output[0]["generated_text"][-1]["content"]
33
 
34
+ def fine_tune_model(base_model_id):
35
+ tokenizer = download_model(base_model_id)
36
+ fine_tuned_model_id = upload_model(base_model_id, tokenizer)
37
 
38
+ return fine_tuned_model_id
39
 
40
+ def download_model(base_model_id):
41
+ tokenizer = AutoTokenizer.from_pretrained(base_model_id)
42
+ model = AutoModelForCausalLM.from_pretrained(base_model_id)
43
+ model.save_pretrained(base_model_id)
44
 
45
  return tokenizer
46
 
 
48
  # ds = load_dataset(dataset)
49
  # return ""
50
 
51
+ def upload_model(base_model_id, tokenizer):
52
+ fine_tuned_model_id = replace_profile(base_model_id)
 
53
 
54
  login(token=os.environ["HF_TOKEN"])
55
 
56
  api = HfApi()
57
+ api.create_repo(repo_id=fine_tuned_model_id)
58
  api.upload_folder(
59
+ folder_path=base_model_id,
60
+ repo_id=fine_tuned_model_id)
61
  )
62
 
63
+ tokenizer.push_to_hub(fine_tuned_model_id)
64
+
65
+ return fine_tuned_model_id
66
 
67
+ def replace_profile(base_model_id):
68
+ model_id = base_model_id[base_model_id.rfind('/')+1:]
69
+ return f"{profile}/{model_id}"
70
 
71
  def process(action, base_model_id, dataset, system_prompt, user_prompt, schema):
72
  if action == action_1:
73
  result = fine_tune_model(base_model_id)
74
  elif action == action_2:
75
+ fine_tuned_model_id = replace_profile(base_model_id)
 
76
  result = prompt_model(fine_tuned_model_id, system_prompt, user_prompt, schema)
77
 
78
  return result