Spaces:
Runtime error
Runtime error
import copy | |
import gc | |
import tempfile | |
import unittest | |
from functools import lru_cache | |
import torch | |
import transformers | |
from hypothesis import given, settings | |
from hypothesis import strategies as st | |
from trlx.data.default_configs import default_ilql_config | |
from trlx.models.modeling_ilql import ( | |
AutoModelForCausalLMWithILQLHeads, | |
AutoModelForSeq2SeqLMWithILQLHeads, | |
ILQLBatch, | |
ILQLConfig, | |
ILQLHeads, | |
batched_index_select, | |
) | |
from trlx.models.modeling_ppo import ( | |
AutoModelForCausalLMWithHydraValueHead, | |
AutoModelForCausalLMWithValueHead, | |
AutoModelForSeq2SeqLMWithHydraValueHead, | |
AutoModelForSeq2SeqLMWithValueHead, | |
) | |
from trlx.trainer.accelerate_ilql_trainer import make_experience | |
# Model identifiers passed to `from_pretrained` by the causal-LM test classes.
AUTO_CAUSAL_LM_PATHS = [
    "gpt2",
    "EleutherAI/pythia-160m",
    "facebook/opt-125m",
]
# Model identifiers passed to `from_pretrained` by the seq2seq-LM test classes.
AUTO_SEQ2SEQ_LM_PATHS = [
    "t5-small",
    "google/flan-t5-small",
]
# Value Head Modeling Tests


class TestAutoModelForCausalLMWithValueHead(unittest.TestCase):
    """Smoke tests for causal-LM wrappers that carry a value head (`v_head`).

    Subclasses override `_auto_model_class` / `_supported_args` to rerun the
    same checks against other wrapper classes. Every check runs once per entry
    of `AUTO_CAUSAL_LM_PATHS`, each inside its own `subTest` so one failing
    checkpoint does not hide the results for the others.
    """

    _auto_model_class = AutoModelForCausalLMWithValueHead
    # Extra keyword arguments forwarded to `from_pretrained` / `from_config`.
    _supported_args = {}

    def setUp(self):
        self.text = "Once upon a time there was a happy goose named Louis. He liked to eat bananas."

    def tearDown(self):
        gc.collect()  # Try to free up memory

    def _create_inputs(self, model_path):
        """Tokenize `self.text` for `model_path` into a (1, 4) left-padded batch.

        Returns:
            dict with `input_ids` and `attention_mask` tensors.
        """
        tokenizer = transformers.AutoTokenizer.from_pretrained(model_path)
        tokenizer.add_special_tokens({"pad_token": "[PAD]"})
        tokenizer.padding_side = "left"
        tokenized = tokenizer(self.text, truncation=True, padding="max_length", max_length=4, return_tensors="pt")
        return dict(input_ids=tokenized.input_ids, attention_mask=tokenized.attention_mask)

    def test_forward(self):
        for model_path in AUTO_CAUSAL_LM_PATHS:
            with self.subTest(model_path=model_path):
                model = self._auto_model_class.from_pretrained(model_path, **self._supported_args)
                inputs = self._create_inputs(model_path)
                # Ensure that the `forward` method doesn't throw an error on generic inputs
                try:
                    model(**inputs)
                except Exception as e:
                    self.fail(f"forward raised {type(e).__name__} for {model_path!r}: {e}")

    def test_generate(self):
        for model_path in AUTO_CAUSAL_LM_PATHS:
            with self.subTest(model_path=model_path):
                model = self._auto_model_class.from_pretrained(model_path, **self._supported_args)
                inputs = self._create_inputs(model_path)
                # Ensure that the `generate` method doesn't throw an error on generic inputs
                try:
                    model.generate(**inputs, return_dict=True, output_hidden_states=True)
                except Exception as e:
                    self.fail(f"generate raised {type(e).__name__} for {model_path!r}: {e}")

    def test_save_load(self):
        for model_path in AUTO_CAUSAL_LM_PATHS:
            with self.subTest(model_path=model_path):
                model = self._auto_model_class.from_pretrained(model_path, **self._supported_args)
                modified_model = copy.deepcopy(model)
                # Manually modify value head parameters so a successful round trip is
                # distinguishable from merely reloading the pretrained checkpoint
                modified_model.v_head[-1].bias = torch.nn.Parameter(torch.tensor([6000053.33]))
                with tempfile.TemporaryDirectory() as tmpdirname:
                    modified_model.save_pretrained(tmpdirname)
                    loaded_model = self._auto_model_class.from_pretrained(tmpdirname)
                    # Check that the loaded model state dict is the same as the saved model state dict
                    saved_state_dict = modified_model.state_dict()
                    loaded_state_dict = loaded_model.state_dict()
                    self.assertEqual(saved_state_dict.keys(), loaded_state_dict.keys())
                    for name, saved_state in saved_state_dict.items():
                        self.assertTrue(torch.all(torch.isclose(saved_state, loaded_state_dict[name])), msg=name)
                    # Assert loaded states are not the same as the original unmodified pretrained model
                    self.assertFalse(torch.all(torch.isclose(modified_model.v_head[-1].bias, model.v_head[-1].bias)))

    def test_from_config(self):
        for model_path in AUTO_CAUSAL_LM_PATHS:
            with self.subTest(model_path=model_path):
                config = transformers.AutoConfig.from_pretrained(model_path)
                # Modify the config to ensure the model is initialized from the custom config
                config.vocab_size = 2
                model = self._auto_model_class.from_config(config, **self._supported_args)
                self.assertEqual(model.base_model.get_output_embeddings().out_features, config.vocab_size)
class TestAutoModelForCausalLMWithHydraValueHead(TestAutoModelForCausalLMWithValueHead):
    """Value-head tests rerun for the hydra wrapper, plus checks on its frozen branch.

    Outputs are compared element-wise with `torch.testing.assert_close`. Summing
    signed differences and testing for zero lets positive and negative errors
    cancel, which can hide real mismatches.
    """

    _auto_model_class = AutoModelForCausalLMWithHydraValueHead
    _supported_args = {"num_layers_unfrozen": 2}  # TODO: Test various values

    def test_forward(self):
        for model_path in AUTO_CAUSAL_LM_PATHS:
            with self.subTest(model_path=model_path):
                model = self._auto_model_class.from_pretrained(model_path, **self._supported_args)
                inputs = self._create_inputs(model_path)
                with torch.no_grad():
                    # Compare logits and hidden states from frozen and unfrozen heads
                    unfrozen_outputs = model(**inputs, return_dict=True, output_hidden_states=True)
                    frozen_outputs = model.forward_hydra(**inputs, return_dict=True, output_hidden_states=True)
                torch.testing.assert_close(
                    frozen_outputs.hidden_states[-1], unfrozen_outputs.hidden_states[-1], check_dtype=False
                )
                torch.testing.assert_close(frozen_outputs.logits, unfrozen_outputs.logits, check_dtype=False)

    def test_lm_heads(self):
        for model_path in AUTO_CAUSAL_LM_PATHS:
            with self.subTest(model_path=model_path):
                model = self._auto_model_class.from_pretrained(model_path, **self._supported_args)
                inputs = self._create_inputs(model_path)
                # Passing the trainable model's last hidden state through the frozen
                # branch's LM head must reproduce the trainable model's logits
                with torch.no_grad():
                    unfrozen_outputs = model(**inputs, return_dict=True, output_hidden_states=True)
                    frozen_logits = model.frozen_head.lm_head(unfrozen_outputs.hidden_states[-1].to(torch.float32))
                torch.testing.assert_close(frozen_logits, unfrozen_outputs.logits, check_dtype=False)

    def test_frozen_head(self):
        for model_path in AUTO_CAUSAL_LM_PATHS:
            with self.subTest(model_path=model_path):
                model = self._auto_model_class.from_pretrained(model_path, **self._supported_args)
                # Ensure that all parameters of the hydra `model.frozen_head` are actually frozen
                for name, parameter in model.frozen_head.named_parameters():
                    self.assertFalse(parameter.requires_grad, msg=name)
class TestAutoModelForSeq2SeqLMWithValueHead(unittest.TestCase):
    """Smoke tests for seq2seq-LM wrappers that carry a value head (`v_head`).

    Subclasses override `_auto_model_class` / `_supported_args` to rerun the
    same checks against other wrapper classes. Every check runs once per entry
    of `AUTO_SEQ2SEQ_LM_PATHS`, each inside its own `subTest`.
    """

    _auto_model_class = AutoModelForSeq2SeqLMWithValueHead
    # Extra keyword arguments forwarded to `from_pretrained` / `from_config`.
    _supported_args = {}

    def setUp(self):
        self.encoder_text = "Translate this text to French: Hello, my dog is cute"
        self.decoder_text = "Bonjour, mon chien est mignon"

    def tearDown(self):
        gc.collect()  # Try to free up memory

    def _create_inputs(self, model_path):
        """Build model inputs for `model_path`.

        The encoder side is left-padded/truncated to 10 tokens. The decoder side
        is the unpadded tokenization of `self.decoder_text`.

        Returns:
            dict with the encoder tokenizer outputs plus `decoder_input_ids` and
            `decoder_attention_mask`.
        """
        tokenizer = transformers.AutoTokenizer.from_pretrained(model_path)
        tokenizer.sep_token = "<sep>"
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.padding_side = "left"
        encoder_inputs = tokenizer(
            self.encoder_text, truncation=True, padding="max_length", max_length=10, return_tensors="pt"
        )
        decoder_inputs = tokenizer(self.decoder_text, return_tensors="pt")
        return {
            **encoder_inputs,
            "decoder_input_ids": decoder_inputs.input_ids,
            "decoder_attention_mask": decoder_inputs.attention_mask,
        }

    def test_forward(self):
        for model_path in AUTO_SEQ2SEQ_LM_PATHS:
            with self.subTest(model_path=model_path):
                model = self._auto_model_class.from_pretrained(model_path, **self._supported_args)
                inputs = self._create_inputs(model_path)
                # Ensure that the `forward` method doesn't throw an error on generic inputs
                try:
                    model(**inputs)
                except Exception as e:
                    self.fail(f"forward raised {type(e).__name__} for {model_path!r}: {e}")

    def test_generate(self):
        for model_path in AUTO_SEQ2SEQ_LM_PATHS:
            with self.subTest(model_path=model_path):
                model = self._auto_model_class.from_pretrained(model_path, **self._supported_args)
                inputs = self._create_inputs(model_path)
                # Ensure that the `generate` method doesn't throw an error on generic inputs
                try:
                    model.generate(inputs["input_ids"])
                except Exception as e:
                    self.fail(f"generate raised {type(e).__name__} for {model_path!r}: {e}")

    def test_save_load(self):
        for model_path in AUTO_SEQ2SEQ_LM_PATHS:
            with self.subTest(model_path=model_path):
                model = self._auto_model_class.from_pretrained(model_path, **self._supported_args)
                modified_model = copy.deepcopy(model)
                # Manually modify value head parameters so a successful round trip is
                # distinguishable from merely reloading the pretrained checkpoint
                modified_model.v_head[-1].bias = torch.nn.Parameter(torch.tensor([6000053.33]))
                with tempfile.TemporaryDirectory() as tmpdirname:
                    modified_model.save_pretrained(tmpdirname)
                    loaded_model = self._auto_model_class.from_pretrained(tmpdirname)
                    # Check that the loaded model state dict is the same as the saved model state dict
                    saved_state_dict = modified_model.state_dict()
                    loaded_state_dict = loaded_model.state_dict()
                    self.assertEqual(saved_state_dict.keys(), loaded_state_dict.keys())
                    for name, saved_state in saved_state_dict.items():
                        self.assertTrue(torch.all(torch.isclose(saved_state, loaded_state_dict[name])), msg=name)
                    # Assert loaded states are not the same as the original unmodified pretrained model
                    self.assertFalse(torch.all(torch.isclose(modified_model.v_head[-1].bias, model.v_head[-1].bias)))

    def test_from_config(self):
        for model_path in AUTO_SEQ2SEQ_LM_PATHS:
            with self.subTest(model_path=model_path):
                config = transformers.AutoConfig.from_pretrained(model_path)
                # Modify the config to ensure the model is initialized from the custom config
                config.vocab_size = 2
                model = self._auto_model_class.from_config(config, **self._supported_args)
                self.assertEqual(model.base_model.get_output_embeddings().out_features, config.vocab_size)
class TestAutoModelForSeq2SeqLMWithHydraValueHead(TestAutoModelForSeq2SeqLMWithValueHead):
    """Seq2seq value-head tests rerun for the hydra wrapper, plus frozen-branch checks.

    Outputs are compared element-wise with `torch.testing.assert_close`. Summing
    signed differences and testing for zero lets positive and negative errors
    cancel, which can hide real mismatches.
    """

    _auto_model_class = AutoModelForSeq2SeqLMWithHydraValueHead
    _supported_args = {"num_layers_unfrozen": 2}  # TODO: Test various values

    def test_forward(self):
        for model_path in AUTO_SEQ2SEQ_LM_PATHS:
            with self.subTest(model_path=model_path):
                model = self._auto_model_class.from_pretrained(model_path, **self._supported_args)
                inputs = self._create_inputs(model_path)
                with torch.no_grad():
                    # Compare logits and decoder hidden states from frozen and unfrozen heads
                    unfrozen_outputs = model(**inputs, return_dict=True, output_hidden_states=True)
                    frozen_outputs = model.forward_hydra(**inputs, return_dict=True, output_hidden_states=True)
                torch.testing.assert_close(
                    frozen_outputs.decoder_hidden_states[-1],
                    unfrozen_outputs.decoder_hidden_states[-1],
                    check_dtype=False,
                )
                torch.testing.assert_close(frozen_outputs.logits, unfrozen_outputs.logits, check_dtype=False)

    def test_lm_heads(self):
        for model_path in AUTO_SEQ2SEQ_LM_PATHS:
            with self.subTest(model_path=model_path):
                model = self._auto_model_class.from_pretrained(model_path, **self._supported_args)
                inputs = self._create_inputs(model_path)
                # Passing the trainable model's last decoder hidden state through the
                # frozen branch's LM head must reproduce the trainable model's logits.
                # NOTE(review): T5 checkpoints with tied embeddings may rescale the decoder
                # output before `lm_head`; confirm this direct comparison holds for t5-small.
                with torch.no_grad():
                    unfrozen_outputs = model(**inputs, return_dict=True, output_hidden_states=True)
                    last_hidden_state = unfrozen_outputs.decoder_hidden_states[-1]
                    frozen_logits = model.frozen_head.lm_head(last_hidden_state)
                torch.testing.assert_close(frozen_logits, unfrozen_outputs.logits, check_dtype=False)

    def test_frozen_head(self):
        for model_path in AUTO_SEQ2SEQ_LM_PATHS:
            with self.subTest(model_path=model_path):
                model = self._auto_model_class.from_pretrained(model_path, **self._supported_args)
                # Ensure that all parameters of the hydra `model.frozen_head` are actually frozen
                for name, parameter in model.frozen_head.named_parameters():
                    self.assertFalse(parameter.requires_grad, msg=name)
# ILQL Heads Modeling Tests


class TestAutoModelForCausalLMWithILQLHeads(unittest.TestCase):
    """Smoke tests for the causal-LM wrapper carrying ILQL heads (`ilql_heads`).

    Every check runs once per entry of `AUTO_CAUSAL_LM_PATHS`, each inside its
    own `subTest` so one failing checkpoint does not hide the others.
    """

    _auto_model_class = AutoModelForCausalLMWithILQLHeads
    _supported_args = {"two_qs": True, "alpha": 0.8}  # TODO: Test various values

    def setUp(self):
        self.text = "Once upon a time there was a happy goose named Louis. He liked to eat bananas."

    def tearDown(self):
        gc.collect()  # Try to free up memory

    def _create_inputs(self, model_path):
        """Tokenize `self.text` for `model_path` into a (1, 4) left-padded batch (tokenizer output)."""
        tokenizer = transformers.AutoTokenizer.from_pretrained(model_path)
        tokenizer.add_special_tokens({"pad_token": "[PAD]"})
        tokenizer.padding_side = "left"
        return tokenizer(self.text, truncation=True, padding="max_length", max_length=4, return_tensors="pt")

    def test_forward(self):
        for model_path in AUTO_CAUSAL_LM_PATHS:
            with self.subTest(model_path=model_path):
                model = self._auto_model_class.from_pretrained(model_path, **self._supported_args)
                inputs = self._create_inputs(model_path)
                # Ensure that the `forward` method doesn't throw an error on generic inputs
                try:
                    model(**inputs)
                except Exception as e:
                    self.fail(f"forward raised {type(e).__name__} for {model_path!r}: {e}")

    def test_generate(self):
        for model_path in AUTO_CAUSAL_LM_PATHS:
            with self.subTest(model_path=model_path):
                model = self._auto_model_class.from_pretrained(model_path, **self._supported_args)
                inputs = self._create_inputs(model_path)
                # Ensure that the `generate` method doesn't throw an error on generic inputs
                try:
                    model.generate(**inputs)
                except Exception as e:
                    self.fail(f"generate raised {type(e).__name__} for {model_path!r}: {e}")

    def test_save_load(self):
        for model_path in AUTO_CAUSAL_LM_PATHS:
            with self.subTest(model_path=model_path):
                model = self._auto_model_class.from_pretrained(model_path, **self._supported_args)
                modified_model = copy.deepcopy(model)
                # Manually modify Q-head parameters so a successful round trip is
                # distinguishable from merely reloading the pretrained checkpoint
                modified_model.ilql_heads.q_heads[0][0].bias = torch.nn.Parameter(
                    torch.ones_like(modified_model.ilql_heads.q_heads[0][0].bias) * 600053.34
                )
                with tempfile.TemporaryDirectory() as tmpdirname:
                    modified_model.save_pretrained(tmpdirname)
                    loaded_model = self._auto_model_class.from_pretrained(tmpdirname)
                    # Check that the loaded model state dict is the same as the saved model state dict
                    saved_state_dict = modified_model.state_dict()
                    loaded_state_dict = loaded_model.state_dict()
                    self.assertEqual(saved_state_dict.keys(), loaded_state_dict.keys())
                    for name, saved_state in saved_state_dict.items():
                        self.assertTrue(torch.all(torch.isclose(saved_state, loaded_state_dict[name])), msg=name)
                    # Assert loaded states are not the same as the original unmodified pretrained model
                    self.assertFalse(
                        torch.all(
                            torch.isclose(
                                modified_model.ilql_heads.q_heads[0][0].bias, model.ilql_heads.q_heads[0][0].bias
                            )
                        )
                    )

    def test_from_config(self):
        for model_path in AUTO_CAUSAL_LM_PATHS:
            with self.subTest(model_path=model_path):
                config = transformers.AutoConfig.from_pretrained(model_path)
                # Modify the config to ensure the model is initialized from the custom config
                config.vocab_size = 2
                model = self._auto_model_class.from_config(config, **self._supported_args)
                self.assertEqual(model.base_model.get_output_embeddings().out_features, config.vocab_size)
class TestAutoModelForSeq2SeqLMWithILQLHeads(unittest.TestCase):
    """Smoke tests for the seq2seq-LM wrapper carrying ILQL heads (`ilql_heads`).

    Every check runs once per entry of `AUTO_SEQ2SEQ_LM_PATHS`, each inside its
    own `subTest` so one failing checkpoint does not hide the others.
    """

    _auto_model_class = AutoModelForSeq2SeqLMWithILQLHeads
    _supported_args = {"two_qs": True, "alpha": 0.8}  # TODO: Test various values

    def setUp(self):
        self.text = "Once upon a time there was a happy goose named Louis. He liked to eat bananas."

    def tearDown(self):
        gc.collect()  # Try to free up memory

    def _create_inputs(self, model_path):
        """Tokenize `self.text` (left-padded/truncated to 4 tokens) and add a one-token decoder prompt.

        The decoder prompt is the pad token id of the tokenizer built here.
        """
        tokenizer = transformers.AutoTokenizer.from_pretrained(model_path)
        tokenizer.padding_side = "left"
        tokenizer.add_special_tokens({"pad_token": "[PAD]"})
        inputs = tokenizer(self.text, truncation=True, padding="max_length", max_length=4, return_tensors="pt")
        inputs["decoder_input_ids"] = torch.tensor([[tokenizer.pad_token_id]])
        return inputs

    def test_forward(self):
        for model_path in AUTO_SEQ2SEQ_LM_PATHS:
            with self.subTest(model_path=model_path):
                model = self._auto_model_class.from_pretrained(model_path, **self._supported_args)
                inputs = self._create_inputs(model_path)
                # Ensure that the `forward` method doesn't throw an error on generic inputs
                try:
                    model(**inputs)
                except Exception as e:
                    self.fail(f"forward raised {type(e).__name__} for {model_path!r}: {e}")

    def test_generate(self):
        for model_path in AUTO_SEQ2SEQ_LM_PATHS:
            with self.subTest(model_path=model_path):
                model = self._auto_model_class.from_pretrained(model_path, **self._supported_args)
                tokenizer = transformers.AutoTokenizer.from_pretrained(model_path)
                inputs = self._create_inputs(model_path)
                inputs["pad_token_id"] = tokenizer.pad_token_id
                inputs["eos_token_id"] = tokenizer.eos_token_id
                # Ensure that the `generate` method doesn't throw an error on generic inputs
                try:
                    model.generate(**inputs)
                except Exception as e:
                    self.fail(f"generate raised {type(e).__name__} for {model_path!r}: {e}")

    def test_save_load(self):
        for model_path in AUTO_SEQ2SEQ_LM_PATHS:
            with self.subTest(model_path=model_path):
                model = self._auto_model_class.from_pretrained(model_path, **self._supported_args)
                modified_model = copy.deepcopy(model)
                # Manually modify Q-head parameters so a successful round trip is
                # distinguishable from merely reloading the pretrained checkpoint
                modified_model.ilql_heads.q_heads[0][0].bias = torch.nn.Parameter(
                    torch.ones_like(modified_model.ilql_heads.q_heads[0][0].bias) * 600053.34
                )
                with tempfile.TemporaryDirectory() as tmpdirname:
                    modified_model.save_pretrained(tmpdirname)
                    loaded_model = self._auto_model_class.from_pretrained(tmpdirname)
                    # Check that the loaded model state dict is the same as the saved model state dict
                    saved_state_dict = modified_model.state_dict()
                    loaded_state_dict = loaded_model.state_dict()
                    self.assertEqual(saved_state_dict.keys(), loaded_state_dict.keys())
                    for name, saved_state in saved_state_dict.items():
                        self.assertTrue(torch.all(torch.isclose(saved_state, loaded_state_dict[name])), msg=name)
                    # Assert loaded states are not the same as the original unmodified pretrained model
                    self.assertFalse(
                        torch.all(
                            torch.isclose(
                                modified_model.ilql_heads.q_heads[0][0].bias, model.ilql_heads.q_heads[0][0].bias
                            )
                        )
                    )

    def test_from_config(self):
        for model_path in AUTO_SEQ2SEQ_LM_PATHS:
            with self.subTest(model_path=model_path):
                config = transformers.AutoConfig.from_pretrained(model_path)
                # Modify the config to ensure the model is initialized from the custom config
                config.vocab_size = 2
                model = self._auto_model_class.from_config(config, **self._supported_args)
                self.assertEqual(model.base_model.get_output_embeddings().out_features, config.vocab_size)
@given(
    st.integers(1, 32),  # batch
    st.integers(1, 32),  # seq_len (>= 1 so that every drawn index is valid)
    st.integers(0, 32),  # num_idxes
    st.integers(1, 32),  # hidden
)
# Disable hypothesis' per-example deadline: torch call timing varies between runs.
@settings(deadline=None)
def test_batched_index_select(batch, seq_len, num_idxes, hidden):
    """`batched_index_select(x, idxs, dim=1)` must equal per-row indexing `x[i, idxs[i]]`."""
    x = torch.randn(batch, seq_len, hidden)
    idxs = torch.randint(0, seq_len, (batch, num_idxes))
    out = batched_index_select(x, idxs, dim=1)

    # Reference result computed row by row. torch.equal also checks the shape,
    # so a wrongly shaped but broadcastable output cannot slip through.
    expected = torch.stack([x[i, idxs[i]] for i in range(batch)])
    assert torch.equal(out, expected)
@given(
    st.integers(1, 32),  # batch_size
    st.integers(1, 32),  # seq_len
    st.integers(0, 32),  # num_action_idxs
    st.integers(0, 32),  # num_state_idxs
    st.integers(1, 32),  # hidden_size
    st.integers(1, 32),  # vocab_size
    st.booleans(),  # two_qs
)
# Disable hypothesis' per-example deadline: head construction time varies between runs.
@settings(deadline=None)
def test_ilql_heads_indexing(batch_size, seq_len, num_action_idxs, num_state_idxs, hidden_size, vocab_size, two_qs):
    """Indexing inside `ILQLHeads.forward` must match indexing its full-sequence outputs afterwards."""
    heads = ILQLHeads(hidden_size, vocab_size, two_qs, alpha=1.0, dtype=torch.float32)
    # heads(hidden_states, states_ixs, actions_ixs) should
    # == heads(hidden_states) followed by indexing with states_ixs and actions_ixs
    hidden_states = torch.randn(batch_size, seq_len, hidden_size)
    states_ixs = torch.randint(0, seq_len, (batch_size, num_state_idxs))
    actions_ixs = torch.randint(0, seq_len, (batch_size, num_action_idxs))
    qs, target_qs, vs = heads(hidden_states, states_ixs, actions_ixs)
    qs2, target_qs2, vs2 = heads(hidden_states)
    assert len(qs2) == len(target_qs2) == len(qs)

    qs2 = tuple(batched_index_select(q, actions_ixs, dim=1) for q in qs2)
    target_qs2 = tuple(batched_index_select(q, actions_ixs, dim=1) for q in target_qs2)
    vs2 = batched_index_select(vs2, states_ixs, dim=1)

    assert all(torch.allclose(q, q2, atol=1e-06) for q, q2 in zip(qs, qs2))
    assert all(torch.allclose(q, q2, atol=1e-06) for q, q2 in zip(target_qs, target_qs2))
    assert torch.allclose(vs, vs2, atol=1e-06)
@given(
    st.integers(1, 32),  # batch_size
    st.integers(1, 32),  # seq_len
    st.integers(0, 32),  # num_action_idxs
    st.integers(0, 32),  # num_state_idxs
    st.integers(1, 32),  # hidden_size
    st.integers(1, 32),  # vocab_size
    st.booleans(),  # two_qs
)
# Disable hypothesis' per-example deadline: head construction time varies between runs.
@settings(deadline=None)
def test_ilql_heads_output_count_and_shape(
    batch_size, seq_len, num_action_idxs, num_state_idxs, hidden_size, vocab_size, two_qs
):
    """`ILQLHeads` returns one or two Q/target-Q tensors (per `two_qs`) with the indexed shapes."""
    heads = ILQLHeads(hidden_size, vocab_size, two_qs, alpha=1.0, dtype=torch.float32)
    hidden_states = torch.randn(batch_size, seq_len, hidden_size)
    states_ixs = torch.randint(0, seq_len, (batch_size, num_state_idxs))
    actions_ixs = torch.randint(0, seq_len, (batch_size, num_action_idxs))
    qs, target_qs, vs = heads(hidden_states, states_ixs, actions_ixs)

    assert len(qs) == len(target_qs)
    assert qs[0].shape == (batch_size, num_action_idxs, vocab_size)
    assert target_qs[0].shape == (batch_size, num_action_idxs, vocab_size)
    assert vs.shape == (batch_size, num_state_idxs, 1)

    if two_qs:
        assert len(qs) == 2
        assert qs[1].shape == (batch_size, num_action_idxs, vocab_size)
        assert target_qs[1].shape == (batch_size, num_action_idxs, vocab_size)
    else:
        assert len(qs) == 1
@given(
    st.integers(1, 32),  # hidden_size
    st.integers(1, 32),  # vocab_size
    st.floats(0.0, 1.0),  # alpha
    st.booleans(),  # two_qs
)
# Disable hypothesis' per-example deadline: head construction time varies between runs.
@settings(deadline=None)
def test_ilql_heads_alpha(hidden_size, vocab_size, alpha, two_qs):
    """After one sync, target heads hold `alpha * q + (1 - alpha) * target`.

    With q params set to 1 and target params set to 0, that is exactly `alpha`.
    """
    heads = ILQLHeads(hidden_size, vocab_size, two_qs, alpha=alpha, dtype=torch.float32)
    for q_head in heads.q_heads:
        for param in q_head.parameters():
            param.data.copy_(torch.ones_like(param.data))
    for target_q_head in heads.target_q_heads:
        for param in target_q_head.parameters():
            param.data.copy_(torch.zeros_like(param.data))

    heads.sync_target_q_heads()

    for target_q_head in heads.target_q_heads:
        for param in target_q_head.parameters():
            assert torch.allclose(param.data, alpha * torch.ones_like(param.data), atol=1e-06)
@given(
    st.integers(1, 32),  # batch_size
    st.integers(1, 32),  # seq_len
    st.integers(1, 32),  # num_action_idxs
    st.integers(1, 32),  # hidden_size
    st.integers(1, 32),  # vocab_size
    st.booleans(),  # two_qs
)
# Disable hypothesis' per-example deadline: head construction time varies between runs.
@settings(deadline=None)
def test_ilql_loss_doesnt_crash(batch_size, seq_len, num_action_idxs, hidden_size, vocab_size, two_qs):
    """`ILQLConfig.loss` runs on random head outputs and a matching random `ILQLBatch`."""
    ilql_config: ILQLConfig = default_ilql_config().method
    ilql_config.two_qs = two_qs
    # One more state index than action indices, matching the `dones` width below
    num_state_idxs = num_action_idxs + 1
    heads = ILQLHeads(hidden_size, vocab_size, two_qs, alpha=1.0, dtype=torch.float32)
    hidden_states = torch.randn(batch_size, seq_len, hidden_size)
    states_ixs = torch.randint(0, seq_len, (batch_size, num_state_idxs))
    actions_ixs = torch.randint(0, seq_len, (batch_size, num_action_idxs))
    qs, target_qs, vs = heads(hidden_states, states_ixs, actions_ixs)
    logits = torch.randn(batch_size, seq_len, vocab_size)
    labels = ILQLBatch(
        input_ids=torch.randint(0, vocab_size, (batch_size, seq_len + 1)),
        attention_mask=torch.ones(batch_size, seq_len, dtype=torch.bool),
        rewards=torch.randn(batch_size, num_action_idxs),
        states_ixs=states_ixs,
        actions_ixs=actions_ixs,
        dones=torch.randint(0, 2, (batch_size, num_state_idxs), dtype=torch.bool),
    )
    loss_input = logits, (qs, target_qs, vs)
    loss, stats = ilql_config.loss(loss_input, labels)
@lru_cache(maxsize=None)
def cached_tokenizer():
    """Load the "gpt2" tokenizer once and reuse it for every call.

    Hypothesis invokes the tests that use this many times per run, so caching
    avoids reloading the tokenizer for each example. Every caller receives the
    same object, so callers must not mutate it.
    """
    return transformers.AutoTokenizer.from_pretrained("gpt2")
@given(
    # (sample, reward) pairs; each sample is a single string, i.e. one turn
    st.lists(st.tuples(st.text(min_size=1), st.floats(0.0, 1.0)), min_size=1, max_size=8),
    st.integers(1, 32),  # hidden_size
    st.booleans(),  # two_qs
)
# Disable the deadline (tokenization and head construction time varies) and cap
# the number of examples, since each one builds full-vocabulary heads.
@settings(deadline=None, max_examples=20)
def test_ilql_loss_make_experience_single_turn(samples_rewards, hidden_size, two_qs):
    """`ILQLConfig.loss` runs on a batch produced by `make_experience` from single-turn samples."""
    # Vocabulary size of the "gpt2" tokenizer returned by cached_tokenizer()
    vocab_size = 50257
    samples, rewards = zip(*samples_rewards)
    batch_size = len(samples)
    rollouts = make_experience(samples, rewards, tokenizer=cached_tokenizer(), verbose=False)

    ilql_config: ILQLConfig = default_ilql_config().method
    # Keep the config's two_qs consistent with the heads built below
    ilql_config.two_qs = two_qs
    # A loader with batch size == number of samples yields every rollout in one batch
    loader = rollouts.create_loader(batch_size)
    ilql_batch = next(iter(loader))
    seq_len = ilql_batch.input_ids.shape[1]

    heads = ILQLHeads(hidden_size, vocab_size, two_qs, alpha=1.0, dtype=torch.float32)
    hidden_states = torch.randn(batch_size, seq_len, hidden_size)
    qs, target_qs, vs = heads(hidden_states, states_ixs=ilql_batch.states_ixs, actions_ixs=ilql_batch.actions_ixs)
    logits = torch.randn(batch_size, seq_len, vocab_size)
    loss, stats = ilql_config.loss((logits, (qs, target_qs, vs)), ilql_batch)