omitakahiro committed
Commit
8f53238
1 Parent(s): 26274a2

Update notebooks/QLoRA.ipynb

Files changed (1)
  1. notebooks/QLoRA.ipynb +10 -1
notebooks/QLoRA.ipynb CHANGED

@@ -109,7 +109,9 @@
  "dataset_name = \"kunishou/databricks-dolly-15k-ja\"\n",
  "dataset = datasets.load_dataset(dataset_name)\n",
  "dataset = dataset.map(encode)\n",
+ "dataset = dataset[\"train\"].train_test_split(0.1)\n",
  "train_dataset = dataset[\"train\"]\n",
+ "eval_dataset = dataset[\"test\"]\n",
  "\n",
  "# load model\n",
  "model = AutoModelForCausalLM.from_pretrained(model_name, device_map=\"auto\", quantization_config=bnb_config, torch_dtype=torch.float16)\n",
@@ -147,16 +149,23 @@
  " learning_rate=2e-4,\n",
  " per_device_train_batch_size=1,\n",
  " gradient_accumulation_steps=16,\n",
+ " per_device_eval_batch_size=1,\n",
  " num_train_epochs=0.1,\n",
  " logging_strategy='steps',\n",
  " logging_steps=10,\n",
- " save_strategy='epoch'\n",
+ " save_strategy='epoch',\n",
+ " evaluation_strategy='epoch',\n",
+ " load_best_model_at_end=True,\n",
+ " metric_for_best_model=\"eval_loss\",\n",
+ " greater_is_better=False,\n",
+ " save_total_limit=2\n",
  ")\n",
  "\n",
  "trainer = Trainer(\n",
  " model=model,\n",
  " args=training_args,\n",
  " train_dataset=train_dataset,\n",
+ " eval_dataset=eval_dataset,\n",
  " data_collator=get_collator(tokenizer, 256)\n",
  ")\n",
  "\n",