Edit model card

Tora_4B

  • rinna/bilingual-gpt-neox-4b-instruction-sftに、1,000,000件の和訳タスクのデータセットをフルパラメータファインチューニングしたモデルです。
  • izumi-lab/llm-japanese-datasetから翻訳タスクのデータセットを抽出し、学習に使用しました。
  • 日英翻訳タスクのデータセットを英日翻訳タスクに修正しました。
  • 日本語から英語への変換(日英翻訳)には対応していません。
  • ryota39/bilingual-gpt-neox-4b-instruction-sft-en-ja-84kも公開しておりますのでご覧ください。

学習

  • ハードウェア: 1 x NVIDIA RTX A6000(VRAM48GB)
  • 使用VRAM: 32~34GB
  • 学習時間: 3h 22m 3s
  • train/epoch: 4
  • train/loss: 1.0551
  • eval/loss: 1.550597071647644
  • optimizer: Adam
  • learning_rate: 1.5e-4
  • lr_scheduler_type: "cosine"
  • warmup_steps: 2400

学習結果

image/png

image/png

コード

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("rinna/bilingual-gpt-neox-4b-instruction-sft", use_fast=False)
model = AutoModelForCausalLM.from_pretrained("ryota39/Tora_4B")

if torch.cuda.is_available():
    model = model.to("cuda")

prompt = list()
prompt.append("指示: 以下の英語を日本語に翻訳してください。")
prompt.append("ユーザー: He delivers a presentation under the title of Stress levels estimation from facial video based on non-contact measurement of pulse wave.")
prompt.append("システム: ")
prompt = '\n'.join(prompt)
print(prompt)

token_ids = tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt")

with torch.no_grad():
    output_ids = model.generate(
        token_ids.to(model.device),
        max_new_tokens=512,
        do_sample=False,
        temperature=0.7,
        top_p=0.85,
        pad_token_id=tokenizer.pad_token_id,
        bos_token_id=tokenizer.bos_token_id,
        eos_token_id=tokenizer.eos_token_id
    )

output = tokenizer.decode(output_ids.tolist()[0][token_ids.size(1):])
print(output)
# 彼は、顔のビデオから心拍数と心拍間隔を推定する方法について話した。
Downloads last month
13
Safetensors
Model size
3.8B params
Tensor type
BF16
·
Inference Examples
This model does not have enough activity to be deployed to Inference API (serverless) yet. Increase its social visibility and check back later, or deploy to Inference Endpoints (dedicated) instead.

Collection including ryota39/Tora_4B