---
datasets:
  - PKU-Alignment/PKU-SafeRLHF
language:
  - en
tags:
  - reinforcement-learning-from-human-feedback
  - reinforcement-learning
  - beaver
  - safety
  - llama
  - ai-safety
  - deepspeed
  - rlhf
  - alpaca
library_name: safe-rlhf
---

# 🦫 Beaver's Cost Model

## Model Details

The Beaver Cost Model is a preference model trained on the PKU-SafeRLHF dataset. It is used in the Safe RLHF algorithm, where it helps make the Beaver model safer and less harmful.

- Developed by: the PKU-Alignment Team.
- Model type: An auto-regressive language model based on the transformer architecture.
- License: Non-commercial license.
- Fine-tuned from model: LLaMA, Alpaca.
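In Safe RLHF, the cost score is treated as a constraint on the policy and handled with a Lagrange multiplier, while the reward score is maximized. The snippet below is a minimal sketch of that generic Lagrangian idea under simplifying assumptions: one scalar reward and cost per sample, a cost limit of 0, and plain dual ascent on the multiplier `lam`. The function names `penalized_reward` and `update_lambda` are hypothetical and this is not the safe-rlhf trainer's code.

```python
# Hypothetical sketch of using a cost model as a constraint via a Lagrange
# multiplier. Names, numbers, and the update rule are illustrative only.


def penalized_reward(reward: float, cost: float, lam: float) -> float:
    """Objective the policy would maximize once the cost is folded in via lam."""
    return reward - lam * cost


def update_lambda(lam: float, mean_cost: float, lr: float, cost_limit: float = 0.0) -> float:
    """Dual ascent: raise lam while the average cost exceeds the limit,
    lower it otherwise, and never let it become negative."""
    return max(0.0, lam + lr * (mean_cost - cost_limit))


if __name__ == "__main__":
    lam = 1.0
    # Hypothetical (reward, cost) pairs scored by a reward model and a cost model.
    batch = [(2.0, 1.5), (1.0, -3.0), (0.5, 0.5)]
    mean_cost = sum(c for _, c in batch) / len(batch)
    print([penalized_reward(r, c, lam) for r, c in batch])
    lam = update_lambda(lam, mean_cost, lr=0.1)
    print(lam)
```

Keeping `lam` non-negative means the cost term can only penalize the objective; when the average cost stays below the limit, `lam` shrinks toward zero and the policy optimizes mostly for reward.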

## Model Sources

## How to Use the Cost Model

```python
from transformers import AutoTokenizer
from safe_rlhf.models import AutoModelForScore

# Load the cost model and its tokenizer from the Hugging Face Hub.
model = AutoModelForScore.from_pretrained('PKU-Alignment/beaver-7b-v1.0-cost', device_map='auto')
tokenizer = AutoTokenizer.from_pretrained('PKU-Alignment/beaver-7b-v1.0-cost', use_fast=False)

# The input follows the conversation template used by the safe-rlhf models.
prompt = 'BEGINNING OF CONVERSATION: USER: hello ASSISTANT:Hello! How can I help you today?'

# Tokenize and move the tensors to the same device as the model.
inputs = tokenizer(prompt, return_tensors='pt').to(model.device)
output = model(**inputs)
print(output)

# `scores` holds one score per token; `end_scores` is the score at the last
# token, i.e. the cost assigned to the whole conversation.
# ScoreModelOutput(
#     scores=tensor([[[-19.6476],
#         [-20.2238],
#         [-21.4228],
#         [-19.2506],
#         [-20.2728],
#         [-23.8799],
#         [-22.6898],
#         [-21.5825],
#         [-21.0855],
#         [-20.2068],
#         [-23.8296],
#         [-21.4940],
#         [-21.9484],
#         [-13.1220],
#         [ -6.4499],
#         [ -8.1982],
#         [ -7.2492],
#         [ -9.3377],
#         [-13.5010],
#         [-10.4932],
#         [ -9.7837],
#         [ -6.4540],
#         [ -6.0084],
#         [ -5.8093],
#         [ -6.6134],
#         [ -5.8995],
#         [ -9.1505],
#         [-11.3254]]], grad_fn=<ToCopyBackward0>),
#     end_scores=tensor([[-11.3254]], grad_fn=<ToCopyBackward0>)
# )
```
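In the output of the example, `end_scores` equals the last entry of `scores`, so the conversation's cost is read off at its final token. When scoring several conversations in one padded batch, each sequence's final real token sits at a different position. The sketch below shows, in plain Python, how one might format each exchange with the same template and pick the score at the last non-padding token using the attention mask. `build_conversation` and `end_score` are hypothetical helpers, not part of the safe-rlhf API.

```python
# Hypothetical helpers for scoring several conversations with the cost model.
# build_conversation() reproduces the template of the example prompt;
# end_score() reads a sequence-level score off per-token scores using an
# attention mask (1 = real token, 0 = padding). Not the library's own code.

from typing import Sequence


def build_conversation(user_message: str, assistant_reply: str) -> str:
    """Format one exchange the same way as the example prompt."""
    return f"BEGINNING OF CONVERSATION: USER: {user_message} ASSISTANT:{assistant_reply}"


def end_score(scores: Sequence[float], attention_mask: Sequence[int]) -> float:
    """Return the score at the last position whose mask value is 1."""
    last = max(i for i, m in enumerate(attention_mask) if m == 1)
    return scores[last]


if __name__ == "__main__":
    print(build_conversation("hello", "Hello! How can I help you today?"))
    # Made-up per-token scores for a right-padded sequence (last two are padding).
    scores = [-13.1, -6.4, -11.3, 0.0, 0.0]
    mask = [1, 1, 1, 0, 0]
    print(end_score(scores, mask))
```

Taking the largest index whose mask is 1 works for both left- and right-padded batches, since in either case it lands on the final real token.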