Fresh_Bench / get_loss /get_loss_hf.py
jijivski's picture
if 'model_cache' not in args.__dict__: # args.model_cache=args.model args.model_cache=None
b9fc40e verified
# import packages
import os
# from tqdm import tqdm
# import warnings
import json
import torch.nn.functional as F
import torch
import gc
from transformers import AutoTokenizer, AutoModelForCausalLM
from datetime import datetime
import argparse
from types import SimpleNamespace
import pdb
# import mamba_ssm
# import rwkv
# RWKV4_TOKENIZER_FILE = "./support/20B_tokenizer.json"
# device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Device that models are loaded onto (passed as device_map in load_hf_model).
# The CUDA auto-detection line above is commented out; change this to 'cuda'
# by hand to run on a GPU.
device = "cpu"
def load_list_from_json(file_path):
    """Read a UTF-8 encoded JSON file and return its decoded contents.

    Callers use this for a list of strings, but whatever JSON value the file
    holds is returned as-is.

    :param file_path: path of the JSON file to read.
    :return: the parsed JSON document (normally a list of strings).
    """
    with open(file_path, encoding='utf-8') as json_file:
        parsed = json.load(json_file)
    return parsed
def calculate_loss(logits, target_token_ids):
    """Return the per-token next-token cross-entropy of one sequence.

    Row ``t`` of ``logits`` is scored against token ``t + 1`` of
    ``target_token_ids``, so the last logit row and the first target are
    dropped before computing the loss.

    Args:
        logits: expected shape (seq_len, vocab_size).
        target_token_ids: expected 1-D tensor of seq_len token ids.

    Returns:
        numpy array of length seq_len - 1 holding the negative log-likelihood
        of each predicted token (empty when seq_len <= 1).
    """
    vocab_size = logits.shape[-1]
    # reshape (rather than view) also accepts non-contiguous inputs.
    loss = F.cross_entropy(
        logits[:-1, :].reshape(-1, vocab_size),
        target_token_ids[1:].reshape(-1),
        reduction='none',
    )
    # detach so the numpy conversion also works outside torch.no_grad().
    return loss.detach().cpu().numpy()
def calculate_log_sum(logits, target_token_ids):
    """Return the summed negative log-likelihood of a sequence as a Python float.

    Row ``t`` of ``logits`` is the prediction for token ``t + 1`` of
    ``target_token_ids``. The last logit row and the first token have no
    partner and are skipped.

    Args:
        logits: expected shape (seq_len, vocab_size).
        target_token_ids: expected 1-D tensor of seq_len token ids.
    """
    next_tokens = target_token_ids[1:].unsqueeze(1)
    log_probs = F.log_softmax(logits[:-1, :], dim=-1)
    nll_per_token = -log_probs.gather(1, next_tokens).squeeze()
    total = torch.sum(nll_per_token, dim=-1)
    return total.item()
def print_model_parameters_in_billions(model):
    """Print the total parameter count of ``model`` in billions (3 decimals)."""
    param_count = 0
    for tensor in model.parameters():
        param_count += tensor.numel()
    print(f"Model parameters: {param_count / 1e9:.3f} billion")
# def make_log(data_dict, folder_path):
# if not os.path.exists(folder_path):
# try:
# os.makedirs(folder_path)
# print(f"Directory created at {folder_path}")
# except Exception as e:
# print(f"Error creating directory: {e}")
# return
# timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
# file_name = f"{timestamp}.json"
# file_path = os.path.join(folder_path, file_name)
# try:
# with open(file_path, 'w') as file:
# json.dump(data_dict, file, indent=4)
# print(f"Dictionary saved successfully to {file_path}")
# except Exception as e:
# print(f"Error saving dictionary: {e}")
# def load_rwkv(path):
# os.environ['RWKV_JIT_ON'] = '1'
# os.environ["RWKV_CUDA_ON"] = '1'
# from rwkv.model import RWKV
# from rwkv.utils import PIPELINE
# rwkv_model = RWKV(model=path, strategy='cuda fp16')
# rwkv_pipeline = PIPELINE(rwkv_model, r"rwkv_vocab_v20230424")
# rwkv_tokenizer = rwkv_pipeline.tokenizer
# return rwkv_model, rwkv_tokenizer
# def load_rwkv4pile(path):
# os.environ['RWKV_JIT_ON'] = '1'
# os.environ["RWKV_CUDA_ON"] = '1'
# from rwkv.model import RWKV
# from rwkv.utils import PIPELINE
# rwkv_model = RWKV(model=path, strategy='cuda fp16')
# rwkv_pipeline = PIPELINE(rwkv_model, RWKV4_TOKENIZER_FILE)
# rwkv_tokenizer = rwkv_pipeline.tokenizer
# return rwkv_model, rwkv_tokenizer
def load_hf_model(path, cache_path):
    """Load a Hugging Face causal LM and its tokenizer onto the module ``device``.

    Args:
        path: model id or local directory, used for both model and tokenizer.
        cache_path: optional download cache directory, or None for the library
            default. It is applied to both the tokenizer and the model so all
            downloaded files land in the same place.

    Returns:
        (model, tokenizer), with the model switched to eval mode.
    """
    # Only forward cache_dir when one was given, so the default applies otherwise.
    cache_kwargs = {} if cache_path is None else {'cache_dir': cache_path}
    hf_tokenizer = AutoTokenizer.from_pretrained(path, **cache_kwargs)
    # NOTE(review): trust_remote_code=True executes Python code shipped with the
    # checkpoint; only point `path` at repositories you trust.
    hf_model = AutoModelForCausalLM.from_pretrained(
        path,
        device_map=device,
        trust_remote_code=True,
        **cache_kwargs,
    ).eval()
    print_model_parameters_in_billions(hf_model)
    return hf_model, hf_tokenizer
# def load_mamba(path):
# from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
# mamba_tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
# mamba_model = MambaLMHeadModel.from_pretrained(path, device="cuda", dtype=torch.float16)
# mamba_model.device = torch.device('cuda')
# print_model_parameters_in_billions(mamba_model)
# return mamba_model, mamba_tokenizer
# def eval_rwkv(model, tokenizer, texts, chunk_size, v4pile=False):
# rwkv_test_data = []
# rwkv_token_length_list = []
# for idx, sample in tqdm(enumerate(texts), total=len(texts)):
# with torch.no_grad():
# if v4pile:
# input_seq = tokenizer.encode(sample).ids # v4
# else:
# input_seq = tokenizer.encode(sample)
# input_length = len(input_seq)
# neg_log_prob_temp = 0
# # for begin in range(0, input_length, chunk_size):
# input_chunk = input_seq[:chunk_size]
# logit = model.forward(input_chunk, None, full_output=True)[0]
# if len(input_chunk) == 1:
# logit = logit.unsqueeze(0)
# log_sum = calculate_log_sum(logit, torch.tensor(input_chunk).cuda())
# neg_log_prob_temp += log_sum
# rwkv_token_length_list.append(input_length)
# rwkv_test_data.append(neg_log_prob_temp)
# data_dict = {
# 'neg_log_prob_sum': sum(rwkv_test_data) / len(rwkv_test_data),
# 'avg tokens': sum(rwkv_token_length_list) / len(rwkv_token_length_list),
# }
# print(f'log probability sum: {sum(rwkv_test_data) / len(rwkv_test_data):.2f}')
# print(f'avg tokens: {sum(rwkv_token_length_list) / len(rwkv_token_length_list):.0f}')
# NOTE(review): the return below was left uncommented when eval_rwkv was
# disabled. As live code it is either a module-level `return` (SyntaxError) or
# unreachable code naming undefined variables, so it is commented out as well.
# return logit,logit,input_chunk,tokenizer
def eval_hf_model(model, tokenizer, texts, chunk_size):
    """Score the leading ``chunk_size`` tokens of ``texts`` with a causal LM.

    Only the first sequence of the tokenized batch is evaluated, and only its
    first ``chunk_size`` tokens; anything beyond that is ignored.

    Args:
        model: Hugging Face causal LM; calling it must return an object with
            a ``.logits`` tensor of shape (batch, seq_len, vocab_size).
        tokenizer: matching tokenizer, called as
            ``tokenizer(texts, return_tensors='pt')``.
        texts: input text. Presumably a single string; a list of
            unequal-length strings would need padding (TODO confirm).
        chunk_size: maximum number of leading tokens fed to the model.

    Returns:
        dict with keys:
          'logit' -- float32 numpy array (seq_len, vocab_size),
          'input_ids' -- numpy array (seq_len,) of the evaluated token ids,
          'loss' -- numpy array (seq_len - 1,) of per-token NLL (calculate_loss),
          'tokenizer' -- the tokenizer passed in,
          'neg_log_prob_temp' -- float, summed NLL (calculate_log_sum).

    Raises:
        ValueError: if no tokens are left to score.
    """
    with torch.no_grad():
        inputs = tokenizer(texts, return_tensors='pt').to(model.device)
        input_chunk = inputs['input_ids'][:, :chunk_size]
        if input_chunk.shape[-1] == 0:
            raise ValueError('texts produced no tokens to score')
        # Call the module (not .forward) so registered hooks run, and upcast to
        # float32 so log-softmax is stable and numpy conversion works for bf16.
        logit = model(input_ids=input_chunk).logits[0, :, :].float()
        # Pair logits[0] with row 0 of the ids (squeeze(0) would not select a
        # row when the batch holds more than one sequence).
        target_ids = input_chunk[0]
        neg_log_prob_temp = calculate_log_sum(logit, target_ids)
        loss = calculate_loss(logit, target_ids)
    return {
        'logit': logit.cpu().numpy(),
        'input_ids': input_chunk.cpu().numpy()[0],
        'loss': loss,
        'tokenizer': tokenizer,
        'neg_log_prob_temp': neg_log_prob_temp,
    }
# if __name__ == '__main__':
# parser = argparse.ArgumentParser()
# parser.add_argument('--model', type=str, required=True, help='model name or path')
# parser.add_argument('--model_type', choices=['hf', 'rwkv', 'mamba', 'rwkv4pile'], required=True, help='model type')
# parser.add_argument('--data', type=str, required=True, help='data path (json file)')
# parser.add_argument('--log_path', type=str, default='./logs/', help='log file path')
# parser.add_argument('--model_cache', type=str, help='hugging face model cache')
# parser.add_argument('--chunk_size', type=int, default=1024, help='chunk size')
def run_get_loss(args=None):
    """Load a causal LM and score ``args.texts`` with it.

    ``args`` is any attribute namespace (``types.SimpleNamespace``,
    ``argparse.Namespace``, ...). Recognised attributes:

    * ``texts`` (required) -- input passed to the tokenizer.
    * ``model`` -- model id or path. If missing, None, or shorter than two
      characters, it falls back to ``'microsoft/phi-2'``.
    * ``model_type`` -- only ``'hf'`` is supported (default ``'hf'``).
    * ``model_cache`` -- Hugging Face cache directory (default None).
    * ``chunk_size`` -- maximum number of tokens evaluated (default 1024).

    When ``args`` is None, a built-in smoke-test configuration is used.
    Defaults are filled in on a copy; the caller's object is not modified.

    Returns:
        The result dict produced by ``eval_hf_model``.

    Raises:
        NotImplementedError: if ``model_type`` is anything other than ``'hf'``.
    """
    if args is None:
        # Portable default: a Hub id rather than a machine-specific local path.
        args = SimpleNamespace(model='microsoft/phi-2', texts='Hello FreshBench !',
                               model_type='hf', model_cache=None, chunk_size=1024)
    # Work on a copy so filling in defaults does not leak into the caller's object.
    args = SimpleNamespace(**vars(args))
    provided = vars(args)
    if 'chunk_size' not in provided:
        args.chunk_size = 1024
    if 'model_type' not in provided:
        args.model_type = 'hf'
    model_name = getattr(args, 'model', None)
    if not model_name or len(model_name) < 2:
        args.model = 'microsoft/phi-2'
    if 'model_cache' not in provided:
        args.model_cache = None

    print('args', args)
    texts = args.texts
    print(f'data size: {len(texts)}')

    if args.model_type != 'hf':
        # RWKV / Mamba loaders exist only as commented-out code in this module.
        raise NotImplementedError(f'unsupported model_type: {args.model_type!r}')

    model, tokenizer = load_hf_model(args.model, args.model_cache)  # tokenizer and model share one path
    print('eval hf')
    return eval_hf_model(model=model, tokenizer=tokenizer, texts=texts, chunk_size=args.chunk_size)
if __name__ == '__main__':
    # Smoke test: score a short greeting with microsoft/phi-2.
    demo_args = SimpleNamespace(
        model='microsoft/phi-2',
        texts='Hello FreshBench !',
        model_type='hf',
        model_cache=None,
        chunk_size=1024,
    )
    run_get_loss(demo_args)