---
license: gemma
library_name: peft
tags:
- trl
- sft
- generated_from_trainer
base_model: google/gemma-1.1-2b-it
model-index:
- name: gemma-2b-it-example-v1
  results: []
language:
- ko
---
# Model Description

GitHub: https://github.com/aiqwe/instruction-tuning-with-rag-example

This model was trained as a worked example for practicing instruction tuning.
It was fine-tuned from the gemma-2b-it model on roughly 10,000 real-estate-related instruction examples.
Please refer to the GitHub repository above for the training code.
# Usage

## Inference on GPU example
```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it")
model = AutoModelForCausalLM.from_pretrained(
    "aiqwe/gemma-2b-it-example-v1",
    device_map="cuda",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2"
)

input_text = "아파트 재건축에 대해 알려줘."
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")

outputs = model.generate(**input_ids, max_new_tokens=512)
print(tokenizer.decode(outputs[0]))
```
## Inference on CPU example
```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it")
model = AutoModelForCausalLM.from_pretrained(
    "aiqwe/gemma-2b-it-example-v1",
    device_map="cpu",
    torch_dtype=torch.bfloat16
)

input_text = "아파트 재건축에 대해 알려줘."
input_ids = tokenizer(input_text, return_tensors="pt").to("cpu")

outputs = model.generate(**input_ids, max_new_tokens=512)
print(tokenizer.decode(outputs[0]))
```
## Inference on GPU with embedded function example

A helper function bundled in the GitHub repository performs RAG through the Naver Search API.
```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from google.colab import userdata  # Colab secrets store holding the Naver API credentials
from utils import generate  # helper from the GitHub repository above

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it")
model = AutoModelForCausalLM.from_pretrained(
    "aiqwe/gemma-2b-it-example-v1",
    device_map="cuda",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2"
)

# Naver Search API credentials used for retrieval
rag_config = {
    "api_client_id": userdata.get('NAVER_API_ID'),
    "api_client_secret": userdata.get('NAVER_API_SECRET')
}

query = "아파트 재건축에 대해 알려줘."
completion = generate(
    model=model,
    tokenizer=tokenizer,
    query=query,
    max_new_tokens=512,
    rag=True,
    rag_config=rag_config
)
print(completion)
```
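For reference, below is a minimal sketch of how such a RAG helper could be structured: fetch snippets from the Naver Search API and fold them into the prompt before generation. The function names, prompt wording, and response fields used here are assumptions for illustration; the actual implementation lives in `utils.py` of the GitHub repository.

```python
# Sketch only -- not the repository's actual utils.generate implementation.
import requests

def naver_search(query, client_id, client_secret, display=5):
    # Query the Naver blog search endpoint; credentials go in the headers.
    resp = requests.get(
        "https://openapi.naver.com/v1/search/blog.json",
        params={"query": query, "display": display},
        headers={
            "X-Naver-Client-Id": client_id,
            "X-Naver-Client-Secret": client_secret,
        },
    )
    resp.raise_for_status()
    return [item["description"] for item in resp.json()["items"]]

def generate_with_rag(model, tokenizer, query, rag_config, max_new_tokens=512):
    # Prepend retrieved context to the user query, then run standard generation.
    context = "\n".join(
        naver_search(query, rag_config["api_client_id"], rag_config["api_client_secret"])
    )
    prompt = f"다음 검색 결과를 참고해서 답변해줘.\n{context}\n\n질문: {query}"
    inputs = tokenizer.apply_chat_template(
        [{"role": "user", "content": prompt}],
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(model.device)
    outputs = model.generate(inputs, max_new_tokens=max_new_tokens)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
```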
# Chat Template

The model uses the Gemma chat template.

## gemma-2b-it Chat Template
```python
input_text = "아파트 재건축에 대해 알려줘."
input_ids = tokenizer.apply_chat_template(
    conversation=[
        {"role": "user", "content": input_text}
    ],
    add_generation_prompt=True,
    return_tensors="pt"
).to(model.device)

outputs = model.generate(input_ids, max_new_tokens=512, repetition_penalty=1.5)
print(tokenizer.decode(outputs[0], skip_special_tokens=False))
```
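To see what the template actually feeds the model, you can render it as text instead of token IDs. The output shown in the comments below assumes the standard Gemma turn markers and is illustrative, not copied from this model's output.

```python
# Render the chat template as plain text (tokenize=False) to inspect the prompt.
prompt_text = tokenizer.apply_chat_template(
    [{"role": "user", "content": "아파트 재건축에 대해 알려줘."}],
    tokenize=False,
    add_generation_prompt=True,
)
print(prompt_text)
# Expected shape (standard Gemma chat format):
# <bos><start_of_turn>user
# 아파트 재건축에 대해 알려줘.<end_of_turn>
# <start_of_turn>model
```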
# Training information

Training was done on a single L4 GPU in Google Colab.
| Item | Value |
|---|---|
| Environment | Google Colab |
| GPU | L4 (22.5GB) |
| VRAM used | approx. 13.8GB |
| dtype | bfloat16 |
| Attention | FlashAttention-2 |
| Tuning | LoRA (r=4, alpha=32) |
| Learning Rate | 1e-4 |
| LR Scheduler | Cosine |
| Optimizer | adamw_torch_fused |
| batch_size | 4 |
| gradient_accumulation_steps | 2 |
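Below is a minimal sketch of a LoRA SFT setup matching the hyperparameters in the table, using `peft` and `trl`. The dataset file, text field, and any argument not listed in the table are assumptions (trl argument names also vary between versions); the actual training code is in the GitHub repository above.

```python
import torch
from datasets import load_dataset
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from trl import SFTTrainer

# Base model and LoRA settings from the table; other choices are illustrative.
peft_config = LoraConfig(r=4, lora_alpha=32, task_type="CAUSAL_LM")

model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-1.1-2b-it",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    device_map="cuda",
)
tokenizer = AutoTokenizer.from_pretrained("google/gemma-1.1-2b-it")

# Hypothetical instruction dataset file with a "text" column.
dataset = load_dataset("json", data_files="real_estate_instructions.json")["train"]

args = TrainingArguments(
    output_dir="gemma-2b-it-example-v1",
    per_device_train_batch_size=4,
    gradient_accumulation_steps=2,
    learning_rate=1e-4,
    lr_scheduler_type="cosine",
    optim="adamw_torch_fused",
    bf16=True,
)

trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset,
    dataset_text_field="text",  # assumed column name
    peft_config=peft_config,
    args=args,
)
trainer.train()
```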