from typing import Callable, Dict, Optional, Union, Tuple

import copy
import functools
import json
import math
import multiprocessing
import os

import torch
import torch.nn as nn
import transformers


class ContextualModelConfig(transformers.configuration_utils.PretrainedConfig):
    """We create a dummy configuration class that will just set properties
    based on whatever kwargs we pass in.

    When this class is initialized (see experiments.py) we pass in the
    union of all data, model, and training args, all of which should
    get saved to the config json.
    """

    def __init__(self, **kwargs):
        for key, value in kwargs.items():
            try:
                # Probe serializability first; only JSON-friendly values are
                # stored on the config.
                json.dumps(value)
                setattr(self, key, value)
            except TypeError:
                # value was not JSON-serializable, skip
                continue
        super().__init__()


def load_embedder_and_tokenizer(name: str) -> Tuple[
    transformers.PreTrainedModel,
    transformers.PreTrainedTokenizer,
]:
    """Load a backbone model and its tokenizer from a short name or hub id.

    Args:
        name: Either one of the aliases handled below (``gtr-base``,
            ``pile-t5-base-encoder``, ...) or a model id that is passed
            straight to ``from_pretrained``.

    Returns:
        A ``(model, tokenizer)`` pair.
    """
    if name.startswith("nomic") or (name == "bert-base-uncased"):
        model = transformers.AutoModel.from_pretrained(name)
        tokenizer = transformers.AutoTokenizer.from_pretrained(name)
    elif name in ["gtr-base", "gtr_base"]:
        # Only the encoder half of the T5 checkpoint is used.
        model = transformers.AutoModel.from_pretrained(
            "sentence-transformers/gtr-t5-base"
        ).encoder
        tokenizer = transformers.AutoTokenizer.from_pretrained(
            "sentence-transformers/gtr-t5-base"
        )
    elif name == "pile-t5-base-encoder":
        model = transformers.AutoModel.from_pretrained(
            "EleutherAI/pile-t5-base"
        ).encoder
        tokenizer = transformers.AutoTokenizer.from_pretrained(
            "EleutherAI/pile-t5-base"
        )
        tokenizer.pad_token = tokenizer.eos_token
    elif name == "pile-t5-base-decoder":
        model = transformers.AutoModel.from_pretrained(
            "EleutherAI/pile-t5-base"
        ).decoder
        tokenizer = transformers.AutoTokenizer.from_pretrained(
            "EleutherAI/pile-t5-base"
        )
        tokenizer.pad_token = tokenizer.eos_token
    elif name.startswith(("gpt2", "meta-llama")) or ("Llama" in name):
        model = transformers.AutoModelForCausalLM.from_pretrained(
            name,
            # torch_dtype=torch.bfloat16,
            attn_implementation="flash_attention_2",
            low_cpu_mem_usage=True,
            # device_map="auto",
        )
        model.padding_side = "right"
        tokenizer = transformers.AutoTokenizer.from_pretrained(name)
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.add_eos_token = True
    else:
        # NOTE(review): trust_remote_code=True executes code shipped with the
        # checkpoint; only pass model ids from sources you trust.
        model = transformers.AutoModel.from_pretrained(name, trust_remote_code=True)
        tokenizer = transformers.AutoTokenizer.from_pretrained(name)

    return model, tokenizer


def get_world_size() -> int:
    """Return the distributed world size, or 1 when no process group exists."""
    try:
        return torch.distributed.get_world_size()
    except (RuntimeError, ValueError):
        return 1


def get_rank() -> int:
    """Return this process's distributed rank, or 0 when no process group exists."""
    try:
        return torch.distributed.get_rank()
    except (RuntimeError, ValueError):
        return 0


def gather(t: torch.Tensor) -> torch.Tensor:
    """Concatenate ``t`` from every worker along dim 0.

    Returns ``t`` unchanged when running on a single worker. The local
    worker's slot is replaced by the original ``t`` after the collective.
    """
    # torch.distributed.nn.all_gather scales by world size since the reduce op is SUM
    # https://github.com/pytorch/pytorch/issues/58005
    # only should use torch.distributed.nn.all_gather if we implement a `local_loss`
    # like: https://github.com/mlfoundations/open_clip/issues/616
    world_size = get_world_size()
    if world_size == 1:
        return t

    if t.ndim == 0:
        # all_gather + cat need at least one dimension.
        t = t.unsqueeze(0)

    gathered = [torch.empty_like(t) for _ in range(world_size)]
    torch.distributed.all_gather(gathered, t)
    gathered[get_rank()] = t

    return torch.cat(gathered, dim=0)


def gather_sum(t: torch.Tensor) -> torch.Tensor:
    """Element-wise sum of ``t`` across all workers.

    Returns ``t`` unchanged when running on a single worker.
    """
    world_size = get_world_size()
    if world_size == 1:
        return t

    if t.ndim == 0:
        t = t.unsqueeze(0)

    gathered = [torch.empty_like(t) for _ in range(world_size)]
    torch.distributed.all_gather(gathered, t)
    return torch.stack(gathered, dim=0).sum(dim=0)  # Sum across workers


def get_num_proc() -> int:
    """Return the number of CPUs available to each worker (at least 1)."""
    world_size: int = get_world_size()
    try:
        # os.sched_getaffinity respects schedulers, unlike cpu_count(), but it's only available
        # on some Unix platforms, so we support both!
        n_cpus = len(os.sched_getaffinity(0))  # type: ignore[attr-defined]
    except AttributeError:
        n_cpus = multiprocessing.cpu_count()
    # Never report zero processes when there are more workers than CPUs.
    return max(1, n_cpus // world_size)


def torch_main_worker_finish_first(func: Callable):
    """Decorator: run ``func`` on the main worker first, then on the others.

    Without an initialized process group the function simply runs once.
    With one, a barrier separates the main worker's call from the other
    workers' calls, and a second barrier follows before returning.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # Get local rank (need to support non-DDP).
        try:
            local_rank = torch.distributed.get_rank()
            ddp_enabled = True
        except (RuntimeError, ValueError):
            local_rank = -1
            ddp_enabled = False
        is_main_worker = local_rank <= 0

        # Run on main worker first.
        if is_main_worker:
            result = func(*args, **kwargs)
        # Then everyone waits.
        if ddp_enabled:
            torch.distributed.barrier()
        # Run on other workers now.
        if not is_main_worker:
            result = func(*args, **kwargs)
        # Now everyone waits again.
        if ddp_enabled:
            torch.distributed.barrier()
        return result

    return wrapper


def print0(*args, **kwargs) -> None:
    """``print`` that only emits output on rank 0."""
    if get_rank() == 0:
        print(*args, **kwargs)


def verify_ddp_weights_equal(model: torch.nn.Module, atol: float = 1e-5) -> None:
    """Check that parameters and gradients agree across workers.

    Parameters without a gradient are skipped, and the whole check is skipped
    when the world size is larger than 8.

    Raises:
        AssertionError: if any parameter or gradient differs from rank 0's
            copy by ``atol`` or more.
    """
    if hasattr(model, "module"):
        model = model.module

    world_size = get_world_size()

    if world_size > 8:
        print0(f"[verify_ddp_weights_equal] Skipping with world_size={world_size} ⚠️")
        return

    for name, param in model.named_parameters():
        if param is None:
            continue
        if param.grad is None:
            print0(f"[verify_ddp_weights_equal] Skipping param [{name}] with no grad")
            continue

        # Row r holds rank r's flattened copy; compare every row to rank 0.
        gathered_param = gather(param).reshape((world_size, -1))
        absolute_diffs = (gathered_param[:1] - gathered_param).abs()
        if not (absolute_diffs < atol).all():
            raise AssertionError(
                f"❌ param [{name}] not equal - got max_absolute_diff={absolute_diffs.max()}"
            )

        gathered_param_grad = gather(param.grad).reshape((world_size, -1))
        absolute_grad_diffs = (gathered_param_grad[:1] - gathered_param_grad).abs()
        if not (absolute_grad_diffs < atol).all():
            raise AssertionError(
                f"❌ param [{name}] grad not equal - got max_absolute_diff={absolute_grad_diffs.max()}"
            )

    print0("[verify_ddp_weights_equal] Verified DDP parameter correctness ✅")


def mean_pool_3d(
    hidden_states: torch.Tensor, attention_mask: torch.Tensor
) -> torch.Tensor:
    """Masked mean over dim 2 of a ``(B, T, S, D)`` tensor.

    Args:
        hidden_states: Tensor of shape ``(B, T, S, D)``.
        attention_mask: Mask broadcastable as ``(B, T, S)``; zero entries
            are excluded from the mean.

    Returns:
        Tensor of shape ``(B, T, D)``. Groups whose mask is entirely zero are
        filled with the unmasked mean over all ``S * T`` positions.
    """
    B, T, S, D = hidden_states.shape
    token_counts = attention_mask.sum(dim=2)[..., None]
    masked_sum = (hidden_states * attention_mask[..., None]).sum(dim=2)
    pooled_outputs = masked_sum / (token_counts + 1e-9)

    # fix for gradient flow: fill empty rows with the mean of the rest of the sequence
    sequence_means = (
        hidden_states.reshape((B, S * T, D))
        .mean(dim=1, keepdim=True)
        .expand(-1, T, -1)
    )
    pooled_outputs = pooled_outputs.where(token_counts > 0, sequence_means)
    assert pooled_outputs.shape == (B, T, D)

    return pooled_outputs


def mean_pool(
    hidden_states: torch.Tensor, attention_mask: torch.Tensor
) -> torch.Tensor:
    """Masked mean over the sequence dim: ``(B, S, D)`` -> ``(B, D)``."""
    B, _S, D = hidden_states.shape
    unmasked_outputs = hidden_states * attention_mask[..., None]
    pooled_outputs = unmasked_outputs.sum(dim=1) / (attention_mask.sum(dim=1)[:, None] + 1e-20)
    assert pooled_outputs.shape == (B, D)
    return pooled_outputs


def max_pool(
    hidden_states: torch.Tensor, attention_mask: torch.Tensor
) -> torch.Tensor:
    """Masked max over the sequence dim: ``(B, S, D)`` -> ``(B, D)``.

    Positions whose mask is zero are replaced by the dtype's minimum value
    so they can never win the max.
    """
    B, _S, D = hidden_states.shape
    fill_value = torch.finfo(hidden_states.dtype).min
    masked = hidden_states.masked_fill(attention_mask[..., None] == 0, fill_value)
    pooled_outputs = masked.max(dim=1).values
    assert pooled_outputs.shape == (B, D)
    return pooled_outputs


def mean_pool_weighted(
    hidden_states: torch.Tensor, attention_mask: torch.Tensor
) -> torch.Tensor:
    """Position-weighted mean over the sequence dim: ``(B, S, D)`` -> ``(B, D)``.

    Each unmasked token is weighted by its 1-based index among the unmasked
    tokens, so later tokens count more. The caller's ``attention_mask`` is
    left untouched. Rows with no unmasked token yield zeros.
    """
    # [0,1,1,1,0,0] -> [0,1,2,3,0,0]; built out-of-place so the caller's mask
    # is not modified.
    weights = attention_mask * attention_mask.cumsum(dim=1)
    s = torch.sum(hidden_states * weights.unsqueeze(-1).float(), dim=1)
    d = weights.sum(dim=1, keepdim=True).float()
    # Non-empty rows have d >= 1, so the clamp only affects all-padding rows
    # (which would otherwise be 0 / 0 = NaN).
    return s / d.clamp(min=1e-9)


def slice_sparse_tensor_rows(t: torch.sparse.Tensor, min_row: int, max_row: int) -> torch.sparse.Tensor:
    """Return rows ``[min_row, max_row)`` of a 2D sparse COO tensor.

    The result has ``max_row - min_row`` rows; row indices are re-based so
    that original row ``min_row`` becomes row 0.
    """
    assert min_row < max_row, f"can't slice from row {min_row} to {max_row}"
    t = t.coalesce()
    row_idxs = t.indices()[0]
    index_mask = (min_row <= row_idxs) & (row_idxs < max_row)

    num_rows = (max_row - min_row)
    num_cols = t.shape[1]

    # Boolean-mask indexing returns a fresh tensor, so re-basing the row
    # indices here does not touch `t`.
    idxs = t.indices()[:, index_mask]
    idxs[0] = idxs[0] - min_row
    vals = t.values()[index_mask]
    return torch.sparse_coo_tensor(idxs, vals, size=(num_rows, num_cols)).coalesce()


def slice_tensor_rows(t: torch.Tensor, min_row: int, max_row: int) -> torch.Tensor:
    """Return rows ``[min_row, max_row)`` of a dense or sparse tensor."""
    if t.is_sparse:
        return slice_sparse_tensor_rows(t=t, min_row=min_row, max_row=max_row)
    return t[min_row:max_row]


@torch.no_grad()
def maxsim(
    X: torch.Tensor,
    y: torch.Tensor,
    maximize: bool,
    chunk_size: int = 8_000,
    debug_mem_usage: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Row-wise max (or min) of ``X @ y``, computed in chunks of rows.

    When a process group is initialized, each worker handles a contiguous,
    disjoint slice of the rows and the per-worker results are combined with
    ``gather_sum`` (untouched rows stay zero on each worker).

    Args:
        X: Dense or sparse matrix with ``n_samples`` rows.
        y: Matrix with ``k`` columns (``k = y.shape[1]``).
        maximize: If true take the max over columns, otherwise the min.
        chunk_size: Number of rows of ``X`` multiplied at a time.
        debug_mem_usage: Print CUDA memory info and shapes per chunk.

    Returns:
        ``(values, indices)``, each of shape ``(n_samples,)``.
    """
    device = X.device
    n_samples = X.shape[0]

    max_sim_v = torch.zeros(n_samples, device=device, dtype=X.dtype)
    max_sim_i = torch.zeros(n_samples, device=device, dtype=torch.int64)

    # TODO: Implement faster max (without going to dense tensors).
    # TODO: Use multiple GPUs.
    rank = get_rank()
    world_size = get_world_size()

    worker_worklist_size = int(math.ceil(n_samples / world_size))
    # Clamp to n_samples so trailing workers get an empty range instead of
    # an out-of-bounds one.
    splits_start_idx = min(worker_worklist_size * rank, n_samples)
    splits_end_idx = min(worker_worklist_size * (rank + 1), n_samples)

    for start in range(splits_start_idx, splits_end_idx, chunk_size):
        # Cap at this worker's split end: rows written by two workers would
        # be double-counted by gather_sum below.
        end = min(start + chunk_size, splits_end_idx)
        sub_x = slice_tensor_rows(X, start, end)
        if debug_mem_usage:
            print(f"[maxsim] step {start} cuda mem free/total = {torch.cuda.mem_get_info()}")
            print("[maxsim] sub_x.shape:", sub_x.shape, "//", "y.shape:", y.shape)
        sub_sim = sub_x @ y  # TODO – Implement sparse max here to save mem!
        if maximize:
            sub_max_sim_v, sub_max_sim_i = sub_sim.to_dense().max(dim=-1)
        else:
            sub_max_sim_v, sub_max_sim_i = sub_sim.to_dense().min(dim=-1)
        del sub_sim
        del sub_x
        torch.cuda.empty_cache()  # needs to happen after maxsim for some reason.
        max_sim_v[start:end] = sub_max_sim_v
        max_sim_i[start:end] = sub_max_sim_i

    # gather
    max_sim_v = gather_sum(max_sim_v)
    max_sim_i = gather_sum(max_sim_i)
    k = y.shape[1]

    assert max_sim_v.shape == (n_samples,)
    assert max_sim_i.shape == (n_samples,)
    assert max_sim_i.min() >= 0
    assert max_sim_i.max() < k  # valid column indices are 0 .. k-1

    return max_sim_v, max_sim_i


def forward_batched(
    model: torch.nn.Module,
    input_ids: torch.Tensor,
    attention_mask: torch.Tensor,
    batch_size: int,
    dataset_input_ids: Optional[torch.Tensor] = None,
    dataset_attention_mask: Optional[torch.Tensor] = None,
    **second_stage_model_kwargs,
) -> torch.Tensor:
    """Run ``model`` over the inputs in mini-batches and concatenate outputs.

    If ``model`` exposes ``first_stage_model`` it is treated as a two-stage
    model: the dataset inputs are embedded in batches by the first stage
    (3D dataset inputs are embedded per leading index and averaged), and the
    resulting embeddings are passed to ``second_stage_model`` for every
    batch of ``input_ids``. Otherwise ``model`` is called directly.

    Raises:
        ValueError: if a two-stage model is given without dataset inputs.
    """
    if hasattr(model, "module"):
        model = model.module

    if hasattr(model, "first_stage_model"):
        if dataset_input_ids is None or dataset_attention_mask is None:
            raise ValueError(
                "dataset_input_ids and dataset_attention_mask are required "
                "for a model with a first_stage_model"
            )
        # Support pooling over 3D dataset_input_ids inputs.
        if len(dataset_input_ids.shape) == 2:
            dataset_input_ids = dataset_input_ids[None]
            dataset_attention_mask = dataset_attention_mask[None]

        dataset_embeddings = []
        for group_ids, group_mask in zip(dataset_input_ids, dataset_attention_mask):
            group_embeddings = [
                model.first_stage_model(
                    input_ids=group_ids[i:i + batch_size],
                    attention_mask=group_mask[i:i + batch_size],
                )
                for i in range(0, dataset_input_ids.shape[1], batch_size)
            ]
            dataset_embeddings.append(torch.cat(group_embeddings, dim=0))

        # Automatically pool over 3D dataset_input_ids.
        dataset_embeddings = torch.stack(dataset_embeddings, dim=0).mean(dim=0)

        outputs = [
            model.second_stage_model(
                input_ids=input_ids[j:j + batch_size],
                attention_mask=attention_mask[j:j + batch_size],
                dataset_embeddings=dataset_embeddings,
                **second_stage_model_kwargs,
            )
            for j in range(0, len(input_ids), batch_size)
        ]
        return torch.cat(outputs, dim=0)

    outputs = [
        model(
            input_ids=input_ids[i:i + batch_size],
            attention_mask=attention_mask[i:i + batch_size],
            **second_stage_model_kwargs,
        )
        for i in range(0, len(input_ids), batch_size)
    ]
    return torch.cat(outputs, dim=0)


def last_token_pool(hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Select the hidden state at the last unmasked position of each row.

    Args:
        hidden_state: Tensor of shape ``(b, n, d)``.
        attention_mask: Tensor of shape ``(b, n)``.

    Returns:
        Tensor of shape ``(b, d)``. The selected vector is multiplied by its
        mask value, so a row whose selected position is masked yields zeros.
    """
    # https://github.com/ContextualAI/gritlm/blob/main/gritlm/gritlm.py#L190
    b, n, d = hidden_state.size()
    # Get the last `1` in the attention mask of each item
    # Often it is just `gather_indices = torch.argmin(attention_mask, 1, keepdim=False) - 1`
    # except when 1) There's all 1's 2) There's 0's before the 1's
    reversed_mask = torch.flip(attention_mask, dims=(1,))
    argmax_reverse = torch.argmax(reversed_mask, dim=1, keepdim=False)
    gather_indices = attention_mask.size(1) - argmax_reverse - 1
    # If there are empty sequences, where the index would become -1 it will crash so set them to 0
    gather_indices = torch.clamp(gather_indices, min=0)
    # Turn indices from shape [b] -> [b, 1, d]
    gather_indices = gather_indices.unsqueeze(-1).repeat(1, d).unsqueeze(1)
    assert gather_indices.shape == (b, 1, d)
    # Gather along the seq len: [b, n, d] -> [b, d]
    # Actually no need for the attention mask as we gather the last token where attn_mask=1 but
    # as some indices (which shouldn't be attended to) may be 0 due to clamp, use mask to ignore them again
    input_mask_expanded = attention_mask.unsqueeze(-1).expand((b, n, d)).float()
    return torch.gather(hidden_state * input_mask_expanded, 1, gather_indices).squeeze(dim=1)


def limit_layers(model: transformers.PreTrainedModel, n_layers: int) -> None:
    """Truncate ``model`` in place to its first ``n_layers`` layers.

    Raises:
        RuntimeError: if the model has neither a ``transformer`` nor an
            ``encoder`` attribute.
    """
    if hasattr(model, 'transformer'):
        if hasattr(model.transformer, 'h'):
            # gpt2
            model.transformer.h = model.transformer.h[:n_layers]
        else:
            model.transformer.layer = model.transformer.layer[:n_layers]
    elif hasattr(model, 'encoder'):
        if hasattr(model.encoder, 'layers'):
            model.encoder.layers = model.encoder.layers[:n_layers]
        else:
            model.encoder.layer = model.encoder.layer[:n_layers]
    else:
        raise RuntimeError(f"unknown how to limit layers of model {type(model)}")


def disable_dropout(model: torch.nn.Module):
    """Set ``p = 0`` on every ``torch.nn.Dropout`` submodule of ``model``."""
    dropout_modules = [m for m in model.modules() if isinstance(m, torch.nn.Dropout)]
    for m in dropout_modules:
        m.p = 0.0
    print0(
        f"Disabled {len(dropout_modules)} dropout modules from model type {type(model)}"
    )


def disable_causality(model: torch.nn.Module):
    """Set ``is_causal = False`` on every submodule that has that attribute."""
    disabled_modules = 0
    for m in model.modules():
        if hasattr(m, "is_causal"):
            m.is_causal = False
            disabled_modules += 1
    print0(
        f"Set is_causal=False in {disabled_modules} modules from model type {type(model)}"
    )


class ContextualModelMixin(nn.Module):
    """Shared logic for models that condition on dataset (corpus) embeddings.

    Subclasses must set ``self.hidden_size`` and ``self.config`` before
    calling ``contextual_init``.
    """

    @property
    def num_corpus_tokens(self) -> int:
        return self.transductive_corpus_size * self.transductive_tokens_per_document

    def contextual_init(self):
        """Create the soft-prompt / output projections and read config options."""
        self.n_soft_prompt = 8
        self.prompt_projection = torch.nn.Sequential(
            torch.nn.Linear(self.hidden_size, self.hidden_size),
            torch.nn.ReLU(),
            torch.nn.Linear(self.hidden_size, self.hidden_size * self.n_soft_prompt)
        )
        self.transductive_corpus_size = vars(self.config).get("transductive_corpus_size", 1)
        self.transductive_tokens_per_document = vars(self.config).get("transductive_tokens_per_document", 1)
        self.randomize_dataset_sequence_order = True
        self.sequence_dropout_prob = vars(self.config).get("transductive_sequence_dropout_prob", 0.0)
        # The null embedding only exists when sequence dropout is enabled;
        # `null_dataset_embedding=True` therefore requires a model configured
        # with transductive_sequence_dropout_prob > 0.
        if self.sequence_dropout_prob > 0.0:
            self.sequence_dropout_null_embedding = torch.nn.Parameter(
                torch.randn(self.hidden_size) * 0.01,
                requires_grad=True
            )
        self.output_projection = torch.nn.Sequential(
            torch.nn.Linear(self.hidden_size, self.hidden_size),
            torch.nn.ReLU(),
            torch.nn.Linear(self.hidden_size, self.hidden_size)
        )

    def _prepare_dataset_embeddings(
        self,
        input_ids: torch.Tensor,
        dataset_embeddings: torch.Tensor,
        null_dataset_embedding: bool = False,
    ) -> torch.Tensor:
        """Build the per-example soft prompt from dataset embeddings.

        The result is ``[dataset embeddings ; n_soft_prompt learned vectors]``
        along dim 1. In training mode the dataset part may have entries
        replaced by the null embedding (sequence dropout) and is randomly
        permuted per example; the learned vectors always stay at the end.

        Args:
            input_ids: Only its device and batch size (``shape[0]``) are used.
            dataset_embeddings: 2D input is given a leading batch dim of 1.
            null_dataset_embedding: Outside of training-time dropout, replace
                every dataset embedding with the learned null embedding.

        Returns:
            Tensor whose dim 1 has ``corpus_size + n_soft_prompt`` entries.
        """
        if not isinstance(dataset_embeddings, torch.Tensor):
            dataset_embeddings = torch.tensor(dataset_embeddings)

        if len(dataset_embeddings.shape) == 2:
            # Auto-expand for a batch.
            dataset_embeddings = dataset_embeddings[None, :, :]  # (b, d) -> (1, b, d)
        dataset_embeddings = dataset_embeddings.to(input_ids.device)

        batch_size = input_ids.shape[0]
        if self.transductive_tokens_per_document > 1:
            if self.training:
                # Choose N random documents to fill our context window with.
                # This logic is a little confusing but allows us to sample a
                # different batch *per-document*
                assert dataset_embeddings.shape[1] == self.transductive_tokens_per_document
                R = torch.randint(
                    low=0,
                    high=len(dataset_embeddings),
                    size=(batch_size, self.config.transductive_corpus_size),
                    device=dataset_embeddings.device
                )
                # TODO make this deterministic somehow for evaluation?
                dataset_embeddings = dataset_embeddings[R].reshape(
                    (batch_size, self.num_corpus_tokens, self.hidden_size)
                )
            else:
                dataset_embeddings = dataset_embeddings.reshape(
                    (1, self.num_corpus_tokens, self.hidden_size)
                )

        if dataset_embeddings.shape[1] > self.num_corpus_tokens:
            # If too many dataset embeddings are passed in, just take the first N until
            # we have the proper number.
            dataset_embeddings = dataset_embeddings[:, :self.num_corpus_tokens, :]

        leading_dim, corpus_size, _hidden_size = dataset_embeddings.shape
        if leading_dim == 1:
            # Auto-expand for a batch.
            dataset_embeddings = dataset_embeddings.expand((batch_size, -1, -1))

        if self.training and self.sequence_dropout_prob > 0.0:
            sequence_dropout_mask = (
                torch.rand((batch_size, corpus_size), device=dataset_embeddings.device)
                < self.sequence_dropout_prob
            )
            null_embeddings = self.sequence_dropout_null_embedding[None, None].expand(
                batch_size, corpus_size, -1
            )
            dataset_embeddings = torch.where(
                sequence_dropout_mask[..., None], null_embeddings, dataset_embeddings
            )
        elif null_dataset_embedding:
            null_embeddings = self.sequence_dropout_null_embedding[None, None].expand(
                batch_size, corpus_size, -1
            )
            dataset_embeddings = null_embeddings

        # The learned soft prompt is the projection of a constant all-ones vector.
        soft_prompt = torch.ones(
            (1, self.hidden_size),
            device=dataset_embeddings.device,
            dtype=dataset_embeddings.dtype,
        )
        soft_prompt = self.prompt_projection(soft_prompt).reshape(
            (1, self.n_soft_prompt, self.hidden_size)
        )
        soft_prompt = soft_prompt.expand((len(dataset_embeddings), -1, -1))
        soft_prompt = torch.cat((dataset_embeddings, soft_prompt), dim=1)

        if self.training and self.randomize_dataset_sequence_order:
            # Shuffle only the corpus positions; the trailing learned tokens
            # keep their order.
            randomized_order = torch.stack(
                [
                    torch.cat(
                        (
                            torch.randperm(corpus_size, device=soft_prompt.device),
                            torch.arange(self.n_soft_prompt, device=soft_prompt.device) + corpus_size
                        ),
                        dim=0,
                    )
                    for _ in range(batch_size)
                ]
            )
            randomized_order = randomized_order.to(soft_prompt.device)
            soft_prompt = soft_prompt.gather(1, randomized_order[..., None].expand_as(soft_prompt))

        return soft_prompt


class BiEncoder(transformers.PreTrainedModel):
    """Embedder followed by pooling and a two-layer MLP head."""

    embedder: transformers.PreTrainedModel

    def __init__(
        self,
        config,  #: transformers.PreTrainedConfig,
    ):
        super().__init__(config=config)
        embedder, _ = load_embedder_and_tokenizer(
            config.embedder,
        )

        if config.limit_layers:
            print0(f"Limiting layers to {config.limit_layers}")
            limit_layers(embedder, config.limit_layers)

        self.embedder = embedder
        self.hidden_size = self.embedder.config.hidden_size
        # Allow pooling to multiple tokens per document
        self.transductive_tokens_per_document = vars(self.config).get("transductive_tokens_per_document", 1)
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(self.hidden_size, self.hidden_size),
            torch.nn.GELU(),
            torch.nn.Linear(self.hidden_size, self.config.embedding_output_dim or self.hidden_size),
        )
        self.temp = config.logit_scale

        if config.disable_dropout:
            disable_dropout(self)
        self.pooling_strategy = vars(config).get("pooling_strategy", "mean")

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        dataset_input_ids: Optional[torch.Tensor] = None,
        dataset_attention_mask: Optional[torch.Tensor] = None,
        token_type_ids=None,
        output_hidden_states: bool = False,
    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
        """Embed, pool, and project the inputs.

        ``dataset_input_ids``, ``dataset_attention_mask`` and
        ``token_type_ids`` are accepted for interface compatibility and are
        not used.

        With ``transductive_tokens_per_document > 1`` the sequence is
        zero-padded to a multiple of that number, split into that many equal
        groups, and each group is mean-pooled separately. Otherwise the whole
        sequence is pooled with the configured ``pooling_strategy`` ("mean",
        or masked max for any other value).

        Returns:
            The projected embeddings, or a dict with keys ``hidden_states``
            and ``pooled`` when ``output_hidden_states`` is true.
        """
        del token_type_ids

        outputs = (
            self.embedder(
                input_ids=input_ids,
                attention_mask=attention_mask,
            ).last_hidden_state
        )

        if self.transductive_tokens_per_document > 1:
            n_groups = self.transductive_tokens_per_document
            batch_size, seq_length, output_dim = outputs.shape

            if seq_length % n_groups != 0:
                # Pad to nearest multiple
                n_extra_embeds = n_groups - (seq_length % n_groups)
                # Match dtypes explicitly so concatenation does not promote
                # the hidden states or turn the mask into floats.
                outputs = torch.cat(
                    (
                        outputs,
                        torch.zeros(
                            (batch_size, n_extra_embeds, output_dim),
                            device=outputs.device,
                            dtype=outputs.dtype,
                        ),
                    ),
                    dim=1
                )
                attention_mask = torch.cat(
                    (
                        attention_mask,
                        torch.zeros(
                            (batch_size, n_extra_embeds),
                            device=attention_mask.device,
                            dtype=attention_mask.dtype,
                        ),
                    ),
                    dim=1
                )
                seq_length += n_extra_embeds
                print(f"Added {n_extra_embeds} padding tokens to input_ids and attention_mask")

            outputs = outputs.reshape(
                (batch_size, n_groups, seq_length // n_groups, output_dim)
            )
            attention_mask = attention_mask.reshape((batch_size, n_groups, -1))
            document_embeddings = mean_pool_3d(outputs, attention_mask)
            document_embeddings = document_embeddings.reshape((batch_size, n_groups, output_dim))
        else:
            if self.pooling_strategy == "mean":
                document_embeddings = mean_pool(outputs, attention_mask)
            else:
                document_embeddings = max_pool(outputs, attention_mask)

        output = self.mlp(document_embeddings)

        if output_hidden_states:
            return {
                "hidden_states": outputs,
                "pooled": output,
            }
        return output


class DatasetConditionedAutoregressive(transformers.PreTrainedModel, ContextualModelMixin):
    """Second-stage model built on a causal-LM backbone.

    The soft prompt (in first-stage hidden size) is flattened, padded and
    re-chunked into backbone-sized vectors, then prepended to the token
    embeddings.
    """

    def __init__(
        self,
        config,
        dataset_backbone: transformers.PreTrainedModel,
        first_stage_hidden_size: int,
    ):
        super().__init__(config=config)
        self.backbone = dataset_backbone
        self.backbone_hidden_size = self.backbone.config.hidden_size
        self.hidden_size = first_stage_hidden_size  # Input token size
        self.contextual_init()
        disable_causality(self.backbone)

        self.input_ln = torch.nn.LayerNorm(
            self.backbone_hidden_size,
            eps=1e-5
        )

        # Override contextual init
        self.output_projection = torch.nn.Sequential(
            torch.nn.Linear(self.backbone_hidden_size, self.backbone_hidden_size),
            torch.nn.ReLU(),
            torch.nn.Linear(self.backbone_hidden_size, self.backbone_hidden_size)
        )
        self._shift_rotary_embedding()

    @property
    def num_corpus_tokens(self) -> int:
        return self.config.transductive_corpus_size * self.transductive_tokens_per_document

    @property
    def corpus_token_ratio(self) -> float:
        # How many tokens from the first stage make one token in the second
        # stage?
        return self.backbone_hidden_size / self.hidden_size

    def corpus_token_pad_size(self, n_tokens: int) -> int:
        # NOTE(review): `n_tokens` is not used by this computation — confirm
        # whether that is intended.
        return self.hidden_size % self.backbone_hidden_size

    def _shift_rotary_embedding(self) -> None:
        # TODO: Can we do this for LLAMA?
        print("Warning: Positional embedding disabling not implemented for LLAMA.")

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        dataset_embeddings: torch.Tensor,
        output_hidden_states: bool = False,
        null_dataset_embedding: bool = False,
    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
        """Encode ``input_ids`` conditioned on ``dataset_embeddings``.

        Pooling over the text positions follows the config's
        ``pooling_strategy``: "last_token", "mean", or position-weighted mean
        for any other value.

        Returns:
            The projected pooled output, or a dict with keys
            ``hidden_states`` and ``pooled`` when ``output_hidden_states``
            is true.
        """
        soft_prompt = self._prepare_dataset_embeddings(
            input_ids=input_ids,
            dataset_embeddings=dataset_embeddings,
            null_dataset_embedding=null_dataset_embedding,
        )

        # Reshape for this model: flatten, pad with ones, and re-chunk into
        # vectors of the backbone's hidden size.
        n_rows = soft_prompt.shape[0]
        num_soft_elements = math.prod(soft_prompt.shape[1:])
        soft_prompt = soft_prompt.reshape((n_rows, num_soft_elements))
        # NOTE(review): when num_soft_elements is already a multiple of the
        # backbone hidden size this appends one full extra token of ones.
        # Left as-is because changing it would alter the sequence seen by
        # existing checkpoints.
        num_padding_elements = self.backbone_hidden_size - (num_soft_elements % self.backbone_hidden_size)
        padding = torch.ones(
            (n_rows, num_padding_elements),
            device=soft_prompt.device,
            dtype=soft_prompt.dtype,  # avoid dtype promotion in the cat below
        )
        soft_prompt = torch.cat((soft_prompt, padding), dim=1)
        soft_prompt = soft_prompt.reshape((n_rows, -1, self.backbone_hidden_size))
        soft_prompt = self.input_ln(soft_prompt)

        # The soft prompt is always fully attended to.
        backbone_attention_mask = torch.ones(
            soft_prompt.shape[0:2],
            dtype=torch.long,
            device=soft_prompt.device,
        )
        token_embeddings = self.backbone.get_input_embeddings()
        inputs_embeds = token_embeddings(input_ids)  # (b, s) -> (b, s, d)
        inputs_embeds = torch.cat((soft_prompt, inputs_embeds), dim=1)
        input_attention_mask = torch.cat((backbone_attention_mask, attention_mask), dim=1)

        output = self.backbone(
            inputs_embeds=inputs_embeds,
            attention_mask=input_attention_mask,
            output_hidden_states=True,
        )
        # trim soft prompt
        last_hidden_state = output.hidden_states[-1]
        n_soft_prompt_tokens = soft_prompt.shape[1]

        output_vectors = last_hidden_state[:, n_soft_prompt_tokens:, :]
        output_attention_mask = input_attention_mask[:, n_soft_prompt_tokens:]

        pooling_strategy = vars(self.config).get("pooling_strategy")
        if pooling_strategy == "last_token":
            output_pooled = last_token_pool(output_vectors, output_attention_mask)
        elif pooling_strategy == "mean":
            output_pooled = mean_pool(output_vectors, output_attention_mask)
        else:
            output_pooled = mean_pool_weighted(output_vectors, output_attention_mask)

        # TODO: Argparse for pooling strategy.
        output = self.output_projection(output_pooled)

        if output_hidden_states:
            return {
                "hidden_states": output_vectors,
                "pooled": output,
            }
        return output


class DatasetConditionedBiencoder(transformers.PreTrainedModel, ContextualModelMixin):
    """Second-stage encoder that prepends the soft prompt to token embeddings."""

    def __init__(
        self,
        config,
        dataset_backbone: transformers.PreTrainedModel,
    ):
        super().__init__(config=config)
        self.backbone = dataset_backbone
        self.hidden_size = self.backbone.config.hidden_size
        self.contextual_init()
        self._shift_rotary_embedding()

    @property
    def num_corpus_tokens(self) -> int:
        return self.config.transductive_corpus_size * self.transductive_tokens_per_document

    def _shift_rotary_embedding(self) -> None:
        """For nomic backbones, start rotary positions after the corpus tokens.

        Controlled by the config option
        ``disable_transductive_rotary_embedding`` (default true).
        """
        disable_transductive_rotary_embedding = vars(self.config).get(
            "disable_transductive_rotary_embedding", True
        )
        if self.backbone.config.model_type.startswith("nomic") and disable_transductive_rotary_embedding:
            # We only want to apply positional embeddings to the
            # *text* portion of the backbone network.
            self.backbone.config.rotary_start_pos = 0.0
            rotary_disabled = 0

            rotary_start_pos = self.num_corpus_tokens
            for module in self.backbone.modules():
                if hasattr(module, "rotary_emb_dim"):
                    module.rotary_start_pos = rotary_start_pos
                    rotary_disabled += 1
            print0(f"modified {rotary_disabled} rotary modules – set rotary_start_pos to {rotary_start_pos}")

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        dataset_embeddings: torch.Tensor,
        output_hidden_states: bool = False,
        null_dataset_embedding: bool = False,
    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
        """Encode ``input_ids`` conditioned on ``dataset_embeddings``.

        Only the text positions (after the soft prompt) are mean-pooled.

        Returns:
            The projected pooled output, or a dict with keys
            ``hidden_states`` and ``pooled`` when ``output_hidden_states``
            is true.
        """
        soft_prompt = self._prepare_dataset_embeddings(
            input_ids=input_ids,
            dataset_embeddings=dataset_embeddings,
            null_dataset_embedding=null_dataset_embedding,
        )
        # The soft prompt is always fully attended to.
        backbone_attention_mask = torch.ones(
            soft_prompt.shape[0:2],
            dtype=torch.long,
            device=soft_prompt.device,
        )
        inputs_embeds = self.backbone.embeddings(input_ids)  # (b, s) -> (b, s, d)
        inputs_embeds = torch.cat((soft_prompt, inputs_embeds), dim=1)
        attention_mask = torch.cat((backbone_attention_mask, attention_mask), dim=1)
        output = self.backbone(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
        )
        # trim soft prompt: use only the text tokens
        n_soft_prompt_tokens = soft_prompt.shape[1]
        output_vectors = output.last_hidden_state[:, n_soft_prompt_tokens:, :]
        output_attention_mask = attention_mask[:, n_soft_prompt_tokens:]
        output_pooled = mean_pool(output_vectors, output_attention_mask)

        # TODO: Argparse for pooling strategy.
        output = self.output_projection(output_pooled)

        if output_hidden_states:
            return {
                "hidden_states": output_vectors,
                "pooled": output,
            }
        return output


class DatasetPrefixBiencoder(transformers.PreTrainedModel, ContextualModelMixin):
    """Encoder that prepends a randomly sampled dataset document as a prefix."""

    def __init__(
        self,
        config,  #: transformers.PreTrainedConfig,
        embedder: transformers.PreTrainedModel,
    ):
        super().__init__(config=config)
        self.embedder = embedder
        self.hidden_size = self.embedder.config.hidden_size
        self.contextual_init()

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        dataset_input_ids: torch.Tensor,
        dataset_attention_mask: torch.Tensor,
        output_hidden_states: bool = False,
    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
        """Encode each input with one random dataset document as its prefix.

        All prefix tokens are attended to, but only the original input
        positions contribute to the pooled output.

        Returns:
            The projected pooled output, or a dict with keys
            ``hidden_states`` (prefix positions removed) and ``pooled`` when
            ``output_hidden_states`` is true.
        """
        R = torch.randint(
            low=0,
            high=len(dataset_input_ids),
            size=(len(input_ids),),
            device=dataset_input_ids.device,
        )

        dataset_input_ids = dataset_input_ids[R]
        input_ids = torch.cat((dataset_input_ids, input_ids), dim=1)

        # All-ones mask over the *sampled* prefix, so its batch dimension
        # matches `attention_mask` even when the dataset size differs from
        # the batch size.
        dataset_attention_mask = torch.ones(
            dataset_input_ids.shape,
            dtype=dataset_attention_mask.dtype,
            device=dataset_attention_mask.device,
        )
        input_attention_mask = torch.cat((dataset_attention_mask, attention_mask), dim=1)
        # Prefix positions are excluded from pooling.
        output_attention_mask = torch.cat(
            (torch.zeros_like(dataset_input_ids), attention_mask), dim=1
        )

        output = self.embedder(
            input_ids=input_ids,
            attention_mask=input_attention_mask,
        )

        output_vectors = output.last_hidden_state
        output_pooled = mean_pool(output_vectors, output_attention_mask)
        output = self.output_projection(output_pooled)

        if output_hidden_states:
            S_d = dataset_attention_mask.shape[1]
            output_vectors = output_vectors[:, S_d:, :]
            return {
                "hidden_states": output_vectors,
                "pooled": output,
            }
        return output


class DatasetTransformer(transformers.PreTrainedModel):
    """Two-stage model: a ``BiEncoder`` embeds the dataset documents, and a
    second-stage model encodes the inputs conditioned on those embeddings.
    """

    config_class = ContextualModelConfig
    embedder: transformers.PreTrainedModel
    dataset_backbone: transformers.PreTrainedModel

    def __init__(
        self,
        config,
    ):
        super().__init__(config=config)
        dataset_backbone, _ = load_embedder_and_tokenizer(
            vars(config).get("dataset_backbone", config.embedder)
        )

        if config.limit_layers:
            print0(f"Limiting layers to {config.limit_layers}")
            limit_layers(dataset_backbone, config.limit_layers)

        # The first stage gets its own config copy: no output-dim override
        # and a separate layer limit.
        biencoder_config = copy.deepcopy(config)
        biencoder_config.embedding_output_dim = None
        biencoder_config.limit_layers = vars(self.config).get("limit_layers_first_stage", None)
        self.first_stage_model = BiEncoder(
            config=biencoder_config,
        )

        if vars(config).get("autoregressive_backbone", False):
            self.second_stage_model = DatasetConditionedAutoregressive(
                config=config,
                dataset_backbone=dataset_backbone,
                first_stage_hidden_size=self.first_stage_model.hidden_size,
            )
        else:
            self.second_stage_model = DatasetConditionedBiencoder(
                config=config,
                dataset_backbone=dataset_backbone
            )

        self.temp = config.logit_scale
        if config.disable_dropout:
            disable_dropout(self)

        transductive_tie_token_embeddings = vars(self.config).get("transductive_tie_token_embeddings", False)
        if transductive_tie_token_embeddings:
            self.second_stage_model.backbone.embeddings.word_embeddings.weight = (
                self.first_stage_model.embedder.embeddings.word_embeddings.weight
            )

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        dataset_input_ids: Optional[torch.Tensor],
        dataset_attention_mask: Optional[torch.Tensor],
        output_hidden_states: bool = False,
    ) -> torch.Tensor:
        """
        input_ids (long torch.Tensor) – ids of input tokens
        attention_mask (bool torch.Tensor)
        """
        dataset_embeddings = self.first_stage_model(
            input_ids=dataset_input_ids,
            attention_mask=dataset_attention_mask
        )
        return self.second_stage_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            dataset_embeddings=dataset_embeddings,
            output_hidden_states=output_hidden_states,
        )


def get_model_class(name: str):
    """Map a model-type name to its class.

    Raises:
        ValueError: if ``name`` is not a known model type.
    """
    # Exact match: `name in 'transductive'` would also accept any substring,
    # including the empty string.
    if name == 'transductive':
        return DatasetTransformer
    elif name == 'biencoder':
        return BiEncoder
    elif name == "dataset_prefix_biencoder":
        return DatasetPrefixBiencoder
    else:
        raise ValueError(f'unknown model cls {name}')