Issue with setting the seed
#13
opened by RINGULARITY
Hello,
I would like to generate predictions with several different seeds to check whether the model is confident in its answers:
from typing import List

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, AwqConfig, set_seed

model_id = "hugging-quants/Meta-Llama-3.1-70B-Instruct-AWQ-INT4"

# Force deterministic cuDNN kernels
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

quantization_config = AwqConfig(
    bits=4,
    fuse_max_seq_len=4096,
    do_fuse=True,
)

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
    device_map="auto",
    quantization_config=quantization_config,
)
def execute_prompt(system_prompt: str, user_prompt: str, model_params: dict, n_tokens: int, seeds: List[int]):
    prompt = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]
    # Sample, and keep the per-step scores so the token distributions can be inspected
    model_params["do_sample"] = True
    model_params["output_scores"] = True
    model_params["return_dict_in_generate"] = True

    results = {}
    for seed in seeds:
        print(seed)
        # Seed every RNG before generating
        set_seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

        inputs = tokenizer.apply_chat_template(
            prompt,
            tokenize=True,
            add_generation_prompt=True,
            return_tensors="pt",
            return_dict=True,
        ).to("cuda")

        with torch.no_grad():
            outputs = model.generate(**inputs, **model_params)

        # Drop the prompt tokens, keep only the newly generated ones
        generated_tokens = outputs.sequences[:, inputs["input_ids"].shape[1]:]
        generated_text = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]

        # outputs.scores holds one tensor per generated step; turn each into probabilities
        logits = outputs.scores
        probs = [torch.nn.functional.softmax(logit, dim=-1) for logit in logits]

        # For each step, record the top-n tokens and their probabilities (as percentages)
        tokens_probs = []
        for prob in probs:
            top_probs, top_indices = torch.topk(prob, n_tokens, dim=-1)
            top_words = [tokenizer.decode([idx.item()]) for idx in top_indices[0]]
            tokens_probs.append({top_words[j]: round(100 * top_probs[0][j].item(), 3) for j in range(n_tokens)})

        results[seed] = (generated_text, tokens_probs)

    return results
The generated_text differs between seeds; however, tokens_probs is the same for every seed. The params are below, with an example call sketched after them:
model_params = {
    "max_new_tokens": 20,
}
n_tokens = 10
seeds = [42, 3982, 2417852, 5261, 78415, 63251]
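This is roughly how I call the function; the system and user prompts below are just placeholders, not my real prompts:

results = execute_prompt(
    system_prompt="You are a helpful assistant.",      # placeholder prompt
    user_prompt="Answer the question with yes or no.",  # placeholder prompt
    model_params=model_params,
    n_tokens=n_tokens,
    seeds=seeds,
)
for seed, (text, probs) in results.items():
    print(seed, text)
    print(probs[0])  # top-10 token probabilities at the first generated step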
Am I missing something?