---
library_name: peft
base_model: google/gemma-2b
license: mit
tags:
- Mathematical Reasoning
language:
- en
datasets:
- adityasihag/math_QAaugP
---
**This repo contains only the LoRA adapter weights**; they are applied on top of the `google/gemma-2b` base model.
### Model Description
- **Project GitHub Page:** https://github.com/adityasihag1996/math_QA.git
- **Developed by:** [Aditya Sihag](https://www.linkedin.com/in/aditya-sihag-ab29681a9/)
- **Model type:** Causal language model, fine-tuned with QLoRA on a single RTX 4090
- **Finetuned from model:** google/gemma-2b
## Results
<table>
<thead>
<tr>
<th>Prompt Approach</th>
<th>GSM8k</th>
<th>MATH</th>
</tr>
</thead>
<tbody>
<tr>
<td>Zero-Shot CoT</td>
<td><b>43.66</b></td>
<td><b>-</b></td>
</tr>
</tbody>
</table>
## Training procedure
The following `bitsandbytes` quantization config was used during training:
- quant_method: bitsandbytes
- load_in_8bit: False
- load_in_4bit: True
- bnb_4bit_quant_type: nf4
- bnb_4bit_use_double_quant: True
- bnb_4bit_compute_dtype: float16
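To illustrate why `nf4` combined with `bnb_4bit_use_double_quant` is memory-efficient, the sketch below estimates the storage cost per quantized weight. It assumes the block sizes described in the QLoRA paper (one quantization constant per 64 weights; with double quantization, those constants are stored in 8 bits and grouped into blocks of 256 that share one fp32 constant). `bits_per_param` is a hypothetical helper, not part of `bitsandbytes`, and `bnb_4bit_compute_dtype` does not affect storage, only the dtype used for matmuls.

```python
# Assumed block sizes (QLoRA paper defaults, not read from this card):
BLOCK_1 = 64   # weights sharing one first-level quantization constant
BLOCK_2 = 256  # first-level constants sharing one second-level fp32 constant


def bits_per_param(double_quant: bool) -> float:
    """Hypothetical helper: average storage bits per 4-bit quantized weight."""
    weight_bits = 4  # one NF4 code per weight
    if not double_quant:
        # One fp32 absmax constant per block of 64 weights
        return weight_bits + 32 / BLOCK_1
    # Constants quantized to 8 bits, plus one fp32 constant per 256 blocks
    return weight_bits + 8 / BLOCK_1 + 32 / (BLOCK_1 * BLOCK_2)


N_WEIGHTS = 1_000_000_000  # hypothetical number of quantized weights
for dq in (False, True):
    bpp = bits_per_param(dq)
    print(f"double_quant={dq}: {bpp:.3f} bits/param, ~{N_WEIGHTS * bpp / 8 / 1e6:.0f} MB")
```

Under these assumptions, double quantization trims the overhead from 0.5 to about 0.127 bits per weight.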
`LoraConfig` parameters:
- r: 128
- lora_alpha: 256 (2 * r)
- lora_dropout: 0.05
- bias: "none"
- task_type: "CAUSAL_LM"
- target_modules: ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
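A quick way to see what this configuration implies is to count the trainable parameters it adds. LoRA wraps each targeted linear layer of shape `d_in -> d_out` with matrices A (`r x d_in`) and B (`d_out x r`), so it adds `r * (d_in + d_out)` parameters per layer, and PEFT's default scaling of the update is `lora_alpha / r`. The sketch below assumes the shapes from the published `google/gemma-2b` config (hidden size 2048, intermediate size 16384, 8 query heads, 1 key/value head, head dimension 256, 18 layers); treat those values as assumptions, and `lora_params` as a hypothetical helper.

```python
def lora_params(d_in: int, d_out: int, r: int) -> int:
    """Hypothetical helper: parameters LoRA adds to one d_in -> d_out linear layer."""
    # A has shape (r, d_in), B has shape (d_out, r)
    return r * (d_in + d_out)


R = 128
LORA_ALPHA = 2 * R          # "lora_alpha: 256 (2 * r)" from the config above
SCALING = LORA_ALPHA / R    # multiplier on the B @ A update in standard LoRA

# Assumed gemma-2b shapes (from its published config.json)
HIDDEN = 2048
INTERMEDIATE = 16384
N_HEADS, N_KV_HEADS, HEAD_DIM = 8, 1, 256
N_LAYERS = 18

# (d_in, d_out) for each targeted projection in one decoder layer
shapes = {
    "q_proj": (HIDDEN, N_HEADS * HEAD_DIM),
    "k_proj": (HIDDEN, N_KV_HEADS * HEAD_DIM),
    "v_proj": (HIDDEN, N_KV_HEADS * HEAD_DIM),
    "o_proj": (N_HEADS * HEAD_DIM, HIDDEN),
    "gate_proj": (HIDDEN, INTERMEDIATE),
    "up_proj": (HIDDEN, INTERMEDIATE),
    "down_proj": (INTERMEDIATE, HIDDEN),
}

per_layer = sum(lora_params(d_in, d_out, R) for d_in, d_out in shapes.values())
total = per_layer * N_LAYERS
print(f"scaling = {SCALING}, LoRA params per layer = {per_layer:,}, total = {total:,}")
```

Most of the adapter's size comes from the MLP projections (`gate_proj`, `up_proj`, `down_proj`), because their intermediate dimension is much larger than the attention projections.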
The hyperparameters for the LoRA fine-tuning are listed below:
- epochs: 3
- learning_rate: 5e-5
- batch_size: 256
- max_grad_norm: 1.0
- weight_decay: 0.001
- lr_scheduler_type: "cosine"
- warmup_ratio: 0.03
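With `lr_scheduler_type: "cosine"` and `warmup_ratio: 0.03`, the learning rate rises linearly over the first 3% of optimizer steps and then follows a half-cosine decay to zero. The plain-Python sketch below reproduces that shape. It assumes the behavior of Hugging Face `transformers`' `get_cosine_schedule_with_warmup` (default `num_cycles=0.5`) and the Trainer rounding warmup steps up to an integer. The card does not give the total step count, so `TOTAL_STEPS` is hypothetical.

```python
import math

BASE_LR = 5e-5
WARMUP_RATIO = 0.03
TOTAL_STEPS = 1000  # hypothetical; depends on dataset size, batch size, and epochs


def lr_at(step: int, total_steps: int, base_lr: float = BASE_LR,
          warmup_ratio: float = WARMUP_RATIO) -> float:
    """Learning rate at a given optimizer step: linear warmup, then cosine decay."""
    warmup_steps = math.ceil(total_steps * warmup_ratio)
    if step < warmup_steps:
        # Linear ramp from 0 to base_lr
        return base_lr * step / max(1, warmup_steps)
    # Fraction of the post-warmup phase completed, in [0, 1]
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return base_lr * max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))


for s in (0, 15, 30, 500, 1000):
    print(s, f"{lr_at(s, TOTAL_STEPS):.2e}")
```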
## Dataset
The math_QA dataset combines [MetaMathQA](https://huggingface.co/datasets/meta-math/MetaMathQA) and [MathInstruct](https://huggingface.co/datasets/TIGER-Lab/MathInstruct) with some internal data.
For details, refer to [math_QAaugP](https://huggingface.co/datasets/adityasihag/math_QAaugP).
## Model Usage
```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

model_path = "google/gemma-2b"

# Load the base model in fp16 on the first GPU
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.float16,
    device_map={"": 0},
)

# Load the LoRA adapter and merge it into the base weights
model = PeftModel.from_pretrained(model, "adityasihag/math_QA-gemma-2B-QLoRA-adapter")
model = model.merge_and_unload()

tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token

question = "Gretchen has 110 coins. There are 30 more gold coins than silver coins. How many gold coins does Gretchen have?"
sample_input = f"Question: {question} \n Answer: "
sample_input_tokenised = tokenizer(sample_input, return_tensors="pt").to(model.device)

generated_ids = model.generate(
    **sample_input_tokenised,
    max_new_tokens=512,
    do_sample=True,   # temperature only takes effect when sampling is enabled
    temperature=0.3,
)
# The decoded text contains the prompt followed by the model's answer
output = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
print(output)
```
##### Sample Input:
```
Question: Gretchen has 110 coins. There are 30 more gold coins than silver coins. How many gold coins does Gretchen have? \n Answer:
```
##### Model Output:
```
Let's assume the number of silver coins is x.
Since there are 30 more gold coins than silver coins, the number of gold coins is x + 30.
The total number of coins is x + (x + 30) = 110.
Combining like terms, we get 2x + 30 = 110.
Subtracting 30 from both sides, we get 2x = 80.
Dividing both sides by 2, we get x = 40.
So, Gretchen has 40 silver coins and 40 + 30 = 70 gold coins.
The answer is: 70
```
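Since the model ends its reasoning with a line of the form `The answer is: 70`, the final number can be pulled out with a regular expression, for example to compare against GSM8k reference answers. This is a minimal sketch assuming that output format; `extract_answer` is a hypothetical helper and not the evaluation code used for the results above.

```python
import re
from typing import Optional

# Matches the final-answer line seen in the sample output, e.g. "The answer is: 70"
ANSWER_RE = re.compile(r"The answer is:\s*(-?[\d,]*\.?\d+)")


def extract_answer(generation: str) -> Optional[str]:
    """Return the last number following 'The answer is:', or None if absent."""
    matches = ANSWER_RE.findall(generation)
    if not matches:
        return None
    # Use the last occurrence and drop thousands separators
    return matches[-1].replace(",", "")


sample = "So, Gretchen has 40 silver coins and 40 + 30 = 70 gold coins.\nThe answer is: 70"
print(extract_answer(sample))
```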
#### Prompt Template:
```
Question: <question>
Answer:
```
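The template can be wrapped in two small helpers: one that builds the prompt and one that drops the prompt from the decoded output of `generate()`. Note that the usage example builds the prompt as `"Question: {question} \n Answer: "`, with spaces around the newline, which differs slightly from the template; the card does not state which exact whitespace was used in training, so this sketch mirrors the usage code. Both `build_prompt` and `strip_prompt` are hypothetical helpers.

```python
def build_prompt(question: str) -> str:
    """Format a question the same way as the usage example."""
    return f"Question: {question} \n Answer: "


def strip_prompt(decoded: str) -> str:
    """Return only the text after the first 'Answer:' marker."""
    # generate() output contains the prompt followed by the completion
    _, sep, completion = decoded.partition("Answer:")
    return completion.strip() if sep else decoded.strip()


prompt = build_prompt("What is 2 + 2?")
print(repr(prompt))
print(strip_prompt(prompt + "2 + 2 = 4\nThe answer is: 4"))
```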
## Comparing math_QA models with other LLMs
| Model                    | GSM8k Pass@1 | MATH Pass@1 |
|--------------------------|--------------|-------------|
| LLaMA-2-7B               | 14.6         | 2.5         |
| gemma-2b                 | 17.7         | -           |
| LLaMA-2-13B              | 28.7         | 3.9         |
| Mistral-7B               | 35.4         | -           |
| LLaMA-2-34B              | 42.2         | 6.24        |
| **math_QA-gemma-2B**     | **43.66**    | -           |
| gemma-7b                 | 46.4         | -           |
| WizardMath-7B            | 54.9         | 10.7        |
| WizardMath-13B           | 63.9         | 14.0        |
| MetaMath-7B              | 66.5         | 19.8        |
| MetaMath-13B             | 72.3         | 22.4        |
| **math_QA-Mistral-7B**   | **75.81**    | -           |
| Arithmo2-Mistral-7B      | 76.4         | 27.2        |
| MetaMath-Mistral-7B      | 77.7         | 28.2        |
| DeepSeekMath-Instruct-7B | 82.9         | 46.8        |
| GPT-4                    | 92.0         | 52.9        |