from typing import Callable, Dict, Optional, Union, Tuple import copy import math import multiprocessing import os import torch import torch.nn as nn import transformers from .misc import ContextualModelConfig def load_embedder_and_tokenizer(name: str) -> Tuple[ transformers.PreTrainedModel, transformers.PreTrainedTokenizer ]: if name.startswith("nomic") or (name == "bert-base-uncased"): model = transformers.AutoModelForMaskedLM.from_pretrained(name, trust_remote_code=True).bert tokenizer = transformers.AutoTokenizer.from_pretrained(name) elif name in ["gtr-base", "gtr_base"]: model = transformers.AutoModel.from_pretrained( "sentence-transformers/gtr-t5-base" ).encoder tokenizer = transformers.AutoTokenizer.from_pretrained( "sentence-transformers/gtr-t5-base" ) elif name == "pile-t5-base-encoder": model = transformers.AutoModel.from_pretrained( "EleutherAI/pile-t5-base" ).encoder tokenizer = transformers.AutoTokenizer.from_pretrained( "EleutherAI/pile-t5-base" ) tokenizer.pad_token = tokenizer.eos_token elif name == "pile-t5-base-decoder": model = transformers.AutoModel.from_pretrained( "EleutherAI/pile-t5-base" ).decoder tokenizer = transformers.AutoTokenizer.from_pretrained( "EleutherAI/pile-t5-base" ) tokenizer.pad_token = tokenizer.eos_token elif name.startswith("gpt2") or name.startswith("meta-llama") or ("Llama" in name): model = transformers.AutoModelForCausalLM.from_pretrained( name, # torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2", low_cpu_mem_usage=True, # device_map="auto", ) model.padding_side = "right" tokenizer = transformers.AutoTokenizer.from_pretrained(name) tokenizer.pad_token = tokenizer.eos_token tokenizer.add_eos_token = True else: model = transformers.AutoModel.from_pretrained(name, trust_remote_code=True) tokenizer = transformers.AutoTokenizer.from_pretrained(name) # if use_bettertransformer: # from optimum.bettertransformer import BetterTransformer # model = BetterTransformer.transform(model) return model, tokenizer def get_world_size() -> int: try: return torch.distributed.get_world_size() except (RuntimeError, ValueError): return 1 def get_rank() -> int: try: return torch.distributed.get_rank() except (RuntimeError, ValueError): return 0 def gather(t: torch.Tensor) -> torch.Tensor: # torch.distributed.nn.all_gather scales by world size since the reduce op is SUM # https://github.com/pytorch/pytorch/issues/58005 # only should use torch.distributed.nn.all_gather if we implement a `local_loss` # like: https://github.com/mlfoundations/open_clip/issues/616 world_size = get_world_size() if world_size == 1: return t if t.ndim == 0: t = t.unsqueeze(0) gathered = [torch.empty_like(t) for _ in range(world_size)] torch.distributed.all_gather(gathered, t) gathered[get_rank()] = t return torch.cat(gathered, dim=0) def gather_sum(t: torch.Tensor) -> torch.Tensor: # torch.distributed.nn.all_gather scales by world size since the reduce op is SUM # https://github.com/pytorch/pytorch/issues/58005 # only should use torch.distributed.nn.all_gather if we implement a `local_loss` # like: https://github.com/mlfoundations/open_clip/issues/616 world_size = get_world_size() if world_size == 1: return t if t.ndim == 0: t = t.unsqueeze(0) gathered = [torch.empty_like(t) for _ in range(world_size)] torch.distributed.all_gather(gathered, t) gathered = torch.stack(gathered, dim=0) return gathered.sum(dim=0) # Sum across workers def get_num_proc() -> int: world_size: int = get_world_size() try: # os.sched_getaffinity respects schedulers, unlike cpu_count(), but it's only available # on some Unix platforms, so we support both! return len(os.sched_getaffinity(0)) // world_size # type: ignore[attr-defined] except AttributeError: return multiprocessing.cpu_count() // world_size def torch_main_worker_finish_first(func: Callable): def wrapper(*args, **kwargs): # Get local rank (need to support non-DDP). try: local_rank = torch.distributed.get_rank() ddp_enabled = True except (RuntimeError, ValueError): local_rank = -1 ddp_enabled = False is_main_worker = local_rank <= 0 # Run on main worker first. if is_main_worker: result = func(*args, **kwargs) # Then everyone waits. if ddp_enabled: torch.distributed.barrier() # Run on other workers now. if not is_main_worker: result = func(*args, **kwargs) # Now everyone waits again. if ddp_enabled: torch.distributed.barrier() return result return wrapper def print0(*args, **kwargs) -> None: if get_rank() == 0: print(*args, **kwargs) def verify_ddp_weights_equal(model: torch.nn.Module, atol: float = 1e-5) -> None: if hasattr(model, "module"): model = model.module world_size = get_world_size() if world_size > 8: print0(f"[verify_ddp_weights_equal] Skipping with world_size={world_size} ⚠️") return for name, param in model.named_parameters(): if param is None: continue if param.grad is None: print0(f"[verify_ddp_weights_equal] Skipping param [{name}] with no grad") continue gathered_param = gather(param).reshape((world_size, -1)) absolute_diffs = (gathered_param[None, 0, :] - gathered_param).abs() rank_params_eq = (absolute_diffs < atol).all() assert rank_params_eq, f"❌ param [{name}] not equal - got max_absolute_diff={absolute_diffs.max()}" ################################################################################################################### gathered_param_grad = gather(param.grad).reshape((world_size, -1)) absolute_grad_diffs = (gathered_param_grad[None, 0, :] - gathered_param_grad).abs() rank_grad_params_eq = (absolute_grad_diffs < atol).all() assert rank_grad_params_eq, f"❌ param [{name}] grad not equal - got max_absolute_diff={absolute_grad_diffs.max()}" ################################################################################################################### print0("[verify_ddp_weights_equal] Verified DDP parameter correctness ✅") def mean_pool_3d( hidden_states: torch.Tensor, attention_mask: torch.Tensor ) -> torch.Tensor: B, T, S, D = hidden_states.shape unmasked_outputs = hidden_states * attention_mask[..., None] pooled_outputs = unmasked_outputs.sum(dim=2) / (attention_mask.sum(dim=2)[..., None] + 1e-9) # fix for gradient flow: fill empty rows with the mean of the rest of the sequence sequence_means = ( hidden_states.reshape((B, S * T, D)) .mean(dim=1, keepdim=True) .expand(-1, T, -1) ) pooled_outputs = pooled_outputs.where( (attention_mask.sum(dim=2)[..., None] > 0), sequence_means ) assert pooled_outputs.shape == (B, T, D) return pooled_outputs def mean_pool( hidden_states: torch.Tensor, attention_mask: torch.Tensor ) -> torch.Tensor: B, _S, D = hidden_states.shape unmasked_outputs = hidden_states * attention_mask[..., None] pooled_outputs = unmasked_outputs.sum(dim=1) / (attention_mask.sum(dim=1)[:, None] + 1e-20) assert pooled_outputs.shape == (B, D) return pooled_outputs def mean_pool_weighted( hidden_states: torch.Tensor, attention_mask: torch.Tensor ) -> torch.Tensor: B, _S, D = hidden_states.shape attention_mask *= attention_mask.cumsum(dim=1) # [0,1,1,1,0,0] -> [0,1,2,3,0,0] s = torch.sum(hidden_states * attention_mask.unsqueeze(-1).float(), dim=1) d = attention_mask.sum(dim=1, keepdim=True).float() return s / d def slice_sparse_tensor_rows(t: torch.sparse.Tensor, min_row: int, max_row: int) -> torch.sparse.Tensor: assert min_row < max_row, f"can't slice from row {min_row} to {max_row}" t = t.coalesce() row_idxs = t.indices()[0] index_mask = (min_row <= row_idxs) & (row_idxs < max_row) num_rows = (max_row - min_row) num_cols = t.shape[1] idxs = t.indices()[:, index_mask] vals = t.values()[index_mask] return torch.sparse_coo_tensor(idxs, vals, size=(num_rows, num_cols)).coalesce() def slice_tensor_rows(t: torch.Tensor, min_row: int, max_row: int) -> torch.Tensor: if t.is_sparse: return slice_sparse_tensor_rows(t=t, min_row=min_row, max_row=max_row) else: return t[min_row:max_row] @torch.no_grad def maxsim( X: torch.Tensor, y: torch.Tensor, maximize: bool, chunk_size: int = 8_000, debug_mem_usage: bool = False) -> torch.Tensor: device = X.device n_samples = X.shape[0] max_sim_v = torch.zeros(n_samples, device=device, dtype=X.dtype) max_sim_i = torch.zeros(n_samples, device=device, dtype=torch.int64) # TODO: Implement faster max (without going to dense tensors). # TODO: Use multiple GPUs. rank = get_rank() world_size = get_world_size() worker_worklist_size = int(math.ceil(n_samples / world_size)) splits_start_idx = worker_worklist_size * rank splits_end_idx = worker_worklist_size * (rank + 1) for i in range(splits_start_idx, splits_end_idx, chunk_size): start, end = i, min(i + chunk_size, n_samples) sub_x = slice_tensor_rows(X, start, end) if debug_mem_usage: print(f"[maxsim] step {i} cuda mem free/total = {torch.cuda.mem_get_info()}") if debug_mem_usage: print("[maxsim] sub_x.shape:", sub_x.shape, "//", "y.shape:", y.shape) sub_sim = sub_x @ y # TODO – Implement sparse max here to save mem! sub_sim = sub_sim if maximize: sub_max_sim_v, sub_max_sim_i = sub_sim.to_dense().max(dim=-1) else: sub_max_sim_v, sub_max_sim_i = sub_sim.to_dense().min(dim=-1) del sub_sim del sub_x torch.cuda.empty_cache() # needs to happen after maxsim for some reason. max_sim_v[start: end] = sub_max_sim_v max_sim_i[start: end] = sub_max_sim_i # gather max_sim_v = gather_sum(max_sim_v) max_sim_i = gather_sum(max_sim_i) k = y.shape[1] assert max_sim_v.shape == (n_samples,) assert max_sim_i.shape == (n_samples,) assert max_sim_i.min() >= 0 assert max_sim_i.max() <= k return max_sim_v, max_sim_i def forward_batched( model: torch.nn.Module, input_ids: torch.Tensor, attention_mask: torch.Tensor, batch_size: int, dataset_input_ids: Optional[torch.Tensor] = None, dataset_attention_mask: Optional[torch.Tensor] = None, **second_stage_model_kwargs, ) -> torch.Tensor: if hasattr(model, "module"): model = model.module if hasattr(model, "first_stage_model"): # Support pooling over 3D dataset_input_ids inputs. if len(dataset_input_ids.shape) == 2: dataset_input_ids = dataset_input_ids[None] dataset_attention_mask = dataset_attention_mask[None] dataset_embeddings = [] for j in range(len(dataset_input_ids)): i = 0 dataset_embeddings_batch = [] while i < dataset_input_ids.shape[1]: dataset_embeddings_batch.append( model.first_stage_model( input_ids=dataset_input_ids[j][i:i+batch_size], attention_mask=dataset_attention_mask[j][i:i+batch_size], ) ) i += batch_size dataset_embeddings.append( torch.cat(dataset_embeddings_batch, dim=0) ) # Automatically pool over 3D dataset_input_ids. dataset_embeddings = torch.stack(dataset_embeddings, dim=0).mean(dim=0) j = 0 outputs = [] while j < len(input_ids): outputs.append( model.second_stage_model( input_ids=input_ids[j:j+batch_size], attention_mask=attention_mask[j:j+batch_size], dataset_embeddings=dataset_embeddings, **second_stage_model_kwargs, ) ) j += batch_size return torch.cat(outputs, dim=0) else: i = 0 outputs = [] while i < len(input_ids): # breakpoint() outputs.append( model( input_ids=input_ids[i:i+batch_size], attention_mask=attention_mask[i:i+batch_size], **second_stage_model_kwargs, ) ) i += batch_size return torch.cat(outputs, dim=0) def last_token_pool(hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor: # https://github.com/ContextualAI/gritlm/blob/main/gritlm/gritlm.py#L190 b, n, d = hidden_state.size() # Get the last `1` in the attention mask of each item # Often it is just `gather_indices = torch.argmin(attention_mask, 1, keepdim=False) - 1` # except when 1) There's all 1's 2) There's 0's before the 1's reversed_mask = torch.flip(attention_mask, dims=(1,)) argmax_reverse = torch.argmax(reversed_mask, dim=1, keepdim=False) gather_indices = attention_mask.size(1) - argmax_reverse - 1 # If there are empty sequences, where the index would become -1 it will crash so set them to 0 gather_indices = torch.clamp(gather_indices, min=0) # Turn indices from shape [b] -> [b, 1, d] gather_indices = gather_indices.unsqueeze(-1).repeat(1, d) gather_indices = gather_indices.unsqueeze(1) assert gather_indices.shape == (b, 1, d) # Gather along the seq len: [b, n, d] -> [b, d] # Actually no need for the attention mask as we gather the last token where attn_mask=1 but # as some indices (which shouldn't be attended to) may be 0 due to clamp, use mask to ignore them again input_mask_expanded = attention_mask.unsqueeze(-1).expand((b, n, d)).float() return torch.gather(hidden_state * input_mask_expanded, 1, gather_indices).squeeze(dim=1) def print0(*args, **kwargs) -> None: if get_rank() == 0: print(*args, **kwargs) def limit_layers(model: transformers.PreTrainedModel, n_layers: int) -> None: if hasattr(model, 'transformer'): if hasattr(model.transformer, 'h'): # gpt2 model.transformer.h = model.transformer.h[:n_layers] else: model.transformer.layer = model.transformer.layer[:n_layers] elif hasattr(model, 'encoder'): if hasattr(model.encoder, 'layers'): model.encoder.layers = model.encoder.layers[:n_layers] else: model.encoder.layer = model.encoder.layer[:n_layers] else: raise RuntimeError(f"unknown how to limit layers of model {type(model)}") def disable_dropout(model: torch.nn.Module): dropout_modules = [m for m in model.modules() if isinstance(m, torch.nn.Dropout)] for m in dropout_modules: m.p = 0.0 print0( f"Disabled {len(dropout_modules)} dropout modules from model type {type(model)}" ) def disable_causality(model: torch.nn.Module): disabled_modules = 0 for m in model.modules(): if hasattr(m, "is_causal"): m.is_causal = False disabled_modules += 1 print0( f"Set is_causal=False in {disabled_modules} modules from model type {type(model)}" ) class ContextualModelMixin(nn.Module): @property def num_corpus_tokens(self) -> int: return self.transductive_corpus_size * self.transductive_tokens_per_document def contextual_init(self): self.n_soft_prompt = 8 self.prompt_projection = torch.nn.Sequential( torch.nn.Linear(self.hidden_size, self.hidden_size), torch.nn.ReLU(), torch.nn.Linear(self.hidden_size, self.hidden_size * self.n_soft_prompt) ) self.transductive_corpus_size = vars(self.config).get("transductive_corpus_size", 1) self.transductive_tokens_per_document = vars(self.config).get("transductive_tokens_per_document", 1) self.randomize_dataset_sequence_order = True self.sequence_dropout_prob = vars(self.config).get("transductive_sequence_dropout_prob", 0.0) if self.sequence_dropout_prob > 0.0: self.sequence_dropout_null_embedding = torch.nn.Parameter( torch.randn(self.hidden_size) * 0.01, requires_grad = True ) self.output_projection = torch.nn.Sequential( torch.nn.Linear(self.hidden_size, self.hidden_size), torch.nn.ReLU(), torch.nn.Linear(self.hidden_size, self.hidden_size) ) def _prepare_dataset_embeddings( self, input_ids: torch.Tensor, dataset_embeddings: torch.Tensor, null_dataset_embedding: bool = False, ) -> torch.Tensor: if not isinstance(dataset_embeddings, torch.Tensor): dataset_embeddings = torch.tensor(dataset_embeddings) if len(dataset_embeddings.shape) == 2: # Auto-expand for a batch. dataset_embeddings = dataset_embeddings[None, :, :] # (b, d) -> (1, b, d) dataset_embeddings = dataset_embeddings.to(input_ids.device) batch_size = input_ids.shape[0] if (self.transductive_tokens_per_document > 1): if self.training: # Choose N random documents to fill our context window with. # This logic is a little confusing but allows us to sample a # different batch *per-document* assert dataset_embeddings.shape[1] == self.transductive_tokens_per_document R = torch.randint( low=0, high=len(dataset_embeddings), size=(batch_size, self.config.transductive_corpus_size), device=dataset_embeddings.device ) # TODO make this deterministic somehow for evaluation? dataset_embeddings = dataset_embeddings[R].reshape((batch_size, self.num_corpus_tokens, self.hidden_size)) else: dataset_embeddings = dataset_embeddings.reshape((1, self.num_corpus_tokens, self.hidden_size)) # print("reshaped to dataset_embeddings.shape =", dataset_embeddings.shape) if dataset_embeddings.shape[1] > self.num_corpus_tokens: # If too many dataset embeddings are passed in, just take the first N until # we have the proper number. dataset_embeddings = dataset_embeddings[:, :self.num_corpus_tokens, :] _, corpus_size, _hidden_size = dataset_embeddings.shape if _ == 1: # Auto-expand for a batch. dataset_embeddings = dataset_embeddings.expand((batch_size, -1, -1)) if self.training and self.sequence_dropout_prob > 0.0: sequence_dropout_mask = ( torch.rand((batch_size, corpus_size), device=dataset_embeddings.device) < self.sequence_dropout_prob ) null_embeddings = self.sequence_dropout_null_embedding[None, None].expand(batch_size, corpus_size, -1) dataset_embeddings = torch.where( sequence_dropout_mask[..., None], null_embeddings, dataset_embeddings ) elif null_dataset_embedding: null_embeddings = self.sequence_dropout_null_embedding[None, None].expand(batch_size, corpus_size, -1) dataset_embeddings = null_embeddings # print(f"[ContextualModelMixin] dataset_embeddings.shape = {dataset_embeddings.shape}") # backbone_max_seq_length = self.backbone.config.max_trained_positions # assert batch_size + (2 * self.n_soft_prompt + corpus_size) <= backbone_max_seq_length, "too many hard negatives for backbone model" soft_prompt = torch.ones((1, self.hidden_size), device=dataset_embeddings.device, dtype=dataset_embeddings.dtype) soft_prompt = self.prompt_projection(soft_prompt).reshape((1, self.n_soft_prompt, self.hidden_size)) soft_prompt = soft_prompt.expand((len(dataset_embeddings), -1, -1)) # -> (b, 4+b, d) # soft_prompt.repeat((len(input_ids), 1, 1)) soft_prompt = torch.cat((dataset_embeddings, soft_prompt), dim=1) # print(f"[ContextualModelMixin] soft_prompt.shape = {soft_prompt.shape}") if self.training and self.randomize_dataset_sequence_order: randomized_order = torch.stack( [ torch.cat( ( torch.randperm(corpus_size, device=soft_prompt.device), torch.arange(self.n_soft_prompt, device=soft_prompt.device) + corpus_size ), dim=0) for _ in range(batch_size)]) randomized_order = randomized_order.to(soft_prompt.device) soft_prompt = soft_prompt.gather(1, randomized_order[..., None].expand_as(soft_prompt)) return soft_prompt class BiEncoder(transformers.PreTrainedModel): embedder: transformers.PreTrainedModel def __init__( self, config, #: transformers.PreTrainedConfig, ): super().__init__(config=config) embedder, _ = load_embedder_and_tokenizer( config.embedder, ) if config.limit_layers: print0(f"Limiting layers to {config.limit_layers}") limit_layers(embedder, config.limit_layers) self.embedder = embedder # if ("t5" in embedder.config.model_type): # print0(f"using torch.compile() on embedder of type `{embedder.config.model_type}`") # self.embedder = torch.compile(self.embedder) self.hidden_size = self.embedder.config.hidden_size # Allow pooling to multiple tokens per document self.transductive_tokens_per_document = vars(self.config).get("transductive_tokens_per_document", 1) self.mlp = torch.nn.Sequential( torch.nn.Linear(self.hidden_size, self.hidden_size), torch.nn.GELU(), torch.nn.Linear(self.hidden_size, self.config.embedding_output_dim or self.hidden_size), ) self.temp = config.logit_scale if config.disable_dropout: disable_dropout(self) self.pooling_strategy = vars(config).get("pooling_strategy", "mean") def forward( self, input_ids: torch.Tensor, attention_mask: torch.Tensor, dataset_input_ids: Optional[torch.Tensor] = None, dataset_attention_mask: Optional[torch.Tensor] = None, token_type_ids = None, output_hidden_states: bool = False, ) -> torch.Tensor: """ query_embedding (float torch.Tensor) - shape (batch_size, embedding_dim) document_embeddings (float torch.Tensor) - shape (corpus_size, embedding_dim) where the corpus_size >= batch_size and is structured like this: [d1, d2, d3, hn1_1, hn1_2, hn2_1, hn2_2, hn3_1, hn3_2] for a corpus with three documents and two hard negatives per document """ # del dataset_input_ids # del dataset_attention_mask del token_type_ids # from cde.lib.dist import get_rank # tokenizer = transformers.AutoTokenizer.from_pretrained("bert-base-uncased") # if get_rank() == 0: # breakpoint() # torch.distributed.barrier() outputs = ( self.embedder( input_ids=input_ids, attention_mask=attention_mask, ).last_hidden_state ) if self.transductive_tokens_per_document > 1: document_embeddings = None batch_size, seq_length, output_dim = outputs.shape if seq_length % self.transductive_tokens_per_document != 0: # Pad to nearest multiple n_extra_embeds = self.transductive_tokens_per_document - (seq_length % self.transductive_tokens_per_document) outputs = torch.cat( (outputs, torch.zeros((batch_size, n_extra_embeds, output_dim), device=outputs.device)), dim=1 ) attention_mask = torch.cat( (attention_mask, torch.zeros((batch_size, n_extra_embeds), device=attention_mask.device)), dim=1 ) seq_length += n_extra_embeds print(f"Added {n_extra_embeds} padding tokens to input_ids and attention_mask") # print("ftransductive_tokens_per_document {self.transductive_tokens_per_document} outputs.shape =", outputs.shape) outputs = outputs.reshape( (batch_size, self.transductive_tokens_per_document, seq_length // self.transductive_tokens_per_document, output_dim) ) attention_mask = attention_mask.reshape((batch_size, self.transductive_tokens_per_document, -1)) document_embeddings = mean_pool_3d(outputs, attention_mask) document_embeddings = document_embeddings.reshape((batch_size, self.transductive_tokens_per_document, output_dim)) else: if self.pooling_strategy == "mean": document_embeddings = mean_pool(outputs, attention_mask) else: document_embeddings = document_embeddings.max(dim=1) output = self.mlp(document_embeddings) if output_hidden_states: return { "hidden_states": outputs, "pooled": output, } else: return output class DatasetConditionedAutoregressive(transformers.PreTrainedModel, ContextualModelMixin): def __init__( self, config, dataset_backbone: transformers.PreTrainedModel, first_stage_hidden_size: int, ): super().__init__(config=config) self.backbone = dataset_backbone self.backbone_hidden_size = self.backbone.config.hidden_size self.hidden_size = first_stage_hidden_size # Input token size self.contextual_init() disable_causality(self.backbone) self.input_ln = torch.nn.LayerNorm( self.backbone_hidden_size, eps=1e-5 ) # Override contextual init self.output_projection = torch.nn.Sequential( torch.nn.Linear(self.backbone_hidden_size, self.backbone_hidden_size), torch.nn.ReLU(), torch.nn.Linear(self.backbone_hidden_size, self.backbone_hidden_size) ) self._shift_rotary_embedding() @property def num_corpus_tokens(self) -> int: return self.config.transductive_corpus_size * self.transductive_tokens_per_document @property def corpus_token_ratio(self) -> float: # How many tokens from the first stage make one token in the second # stage? return self.backbone_hidden_size / self.hidden_size def corpus_token_pad_size(self, n_tokens: int) -> int: return self.hidden_size % self.backbone_hidden_size def _shift_rotary_embedding(self) -> None: disable_transductive_rotary_embedding = vars(self.config).get("disable_transductive_rotary_embedding", True) # TODO: Can we do this for LLAMA? print("Warning: Positional embedding disabling not implemented for LLAMA.") def forward( self, input_ids: torch.Tensor, attention_mask: torch.Tensor, dataset_embeddings: torch.Tensor, output_hidden_states: bool = False, null_dataset_embedding: bool = False, ) -> torch.Tensor: soft_prompt = self._prepare_dataset_embeddings( input_ids=input_ids, dataset_embeddings=dataset_embeddings, null_dataset_embedding=null_dataset_embedding, ) # Reshape for this model. # print("[DatasetConditionedAutoregressive] 1 -> soft_prompt.shape =", soft_prompt.shape) num_soft_elements = torch.prod(torch.tensor(soft_prompt.shape[1:])).item() soft_prompt = soft_prompt.reshape((soft_prompt.shape[0], num_soft_elements)) num_padding_elements = self.backbone_hidden_size - (num_soft_elements % self.backbone_hidden_size) padding = torch.ones((soft_prompt.shape[0], num_padding_elements), device=soft_prompt.device) soft_prompt = torch.cat((soft_prompt, padding), dim=1) soft_prompt = soft_prompt.reshape( (soft_prompt.shape[0], -1, self.backbone_hidden_size) ) soft_prompt = self.input_ln(soft_prompt) # print("[DatasetConditionedAutoregressive] 2 -> soft_prompt.shape =", soft_prompt.shape) backbone_attention_mask = torch.ones( soft_prompt.shape[0:2], dtype=torch.long, device=soft_prompt.device, ) token_embeddings = self.backbone.get_input_embeddings() inputs_embeds = token_embeddings(input_ids) # (b, s) -> (b, s, d) # print("[2] inputs_embeds.shape =", inputs_embeds.shape) inputs_embeds = torch.cat((soft_prompt, inputs_embeds), dim=1) # (v, 4+b+s, d) # print("[3.a] inputs_embeds.shape =", inputs_embeds.shape) input_attention_mask = torch.cat((backbone_attention_mask, attention_mask), dim=1) # print("[3.b] attention_mask.shape =", attention_mask.shape) output = self.backbone( inputs_embeds=inputs_embeds, attention_mask=input_attention_mask, output_hidden_states=True, ) # (1, 4 + b + s, d) # trim soft prompt last_hidden_state = output.hidden_states[-1] n_soft_prompt_tokens = soft_prompt.shape[1] output_vectors = last_hidden_state[:, n_soft_prompt_tokens:, :] output_attention_mask = input_attention_mask[:, n_soft_prompt_tokens:] # Take last token position if vars(self.config).get("pooling_strategy") == "last_token": output_pooled = last_token_pool(output_vectors, output_attention_mask) elif vars(self.config).get("pooling_strategy") == "mean": output_pooled = mean_pool(output_vectors, output_attention_mask) else: output_pooled = mean_pool_weighted(output_vectors, output_attention_mask) # average with original vectors # TODO: Argparse for pooling strategy. output = self.output_projection(output_pooled) # (b, 2d) -> (b, d) if output_hidden_states: return { "hidden_states": output_vectors, "pooled": output, } else: return output class DatasetConditionedBiencoder(transformers.PreTrainedModel, ContextualModelMixin): def __init__( self, config, dataset_backbone: transformers.PreTrainedModel, ): super().__init__(config=config) self.backbone = dataset_backbone self.hidden_size = self.backbone.config.hidden_size self.hidden_size = dataset_backbone.config.hidden_size # self.input_ln = torch.nn.LayerNorm( # self.hidden_size, # eps=self.backbone.config.layer_norm_epsilon # ) self.contextual_init() self._shift_rotary_embedding() @property def num_corpus_tokens(self) -> int: return self.config.transductive_corpus_size * self.transductive_tokens_per_document def _shift_rotary_embedding(self) -> None: disable_transductive_rotary_embedding = vars(self.config).get("disable_transductive_rotary_embedding", True) if self.backbone.config.model_type.startswith("nomic") and disable_transductive_rotary_embedding: # We only want to apply positional embeddings to the # *text* portion of the backbone network. self.backbone.config.rotary_start_pos = 0.0 rotary_disabled = 0 rotary_start_pos = self.num_corpus_tokens for module in self.backbone.modules(): if hasattr(module, "rotary_emb_dim"): module.rotary_start_pos = rotary_start_pos rotary_disabled += 1 print0(f"modified {rotary_disabled} rotary modules – set rotary_start_pos to {rotary_start_pos}") def forward( self, input_ids: torch.Tensor, attention_mask: torch.Tensor, dataset_embeddings: torch.Tensor, output_hidden_states: bool = False, null_dataset_embedding: bool = False, ) -> torch.Tensor: # print(f"[DatasetConditionedBiencoder - 0] input_ids.shape => {input_ids.shape} // dataset_embeddings.shape =", dataset_embeddings.shape) soft_prompt = self._prepare_dataset_embeddings( input_ids=input_ids, dataset_embeddings=dataset_embeddings, null_dataset_embedding=null_dataset_embedding, ) # print(f"[DatasetConditionedBiencoder - 1] soft_prompt.shape => {soft_prompt.shape}") backbone_attention_mask = torch.ones( soft_prompt.shape[0:2], dtype=torch.long, device=soft_prompt.device, ) inputs_embeds = self.backbone.embeddings(input_ids) # (b, s) -> (b, s, d) # print("[2] inputs_embeds.shape =", inputs_embeds.shape) inputs_embeds = torch.cat((soft_prompt, inputs_embeds), dim=1) # (v, 4+b+s, d) # print("[3.a] inputs_embeds.shape =", inputs_embeds.shape) attention_mask = torch.cat((backbone_attention_mask, attention_mask), dim=1) # print("[3.b] attention_mask.shape =", attention_mask.shape) output = self.backbone( inputs_embeds=inputs_embeds, attention_mask=attention_mask, ) # (1, 4 + b + s, d) # trim soft prompt output_vectors = output.last_hidden_state # use only these tokens n_soft_prompt_tokens = soft_prompt.shape[1] # print("n_soft_prompt_tokens =", n_soft_prompt_tokens) output_vectors = output.last_hidden_state[:, n_soft_prompt_tokens:, :] output_attention_mask = attention_mask[:, n_soft_prompt_tokens:] # print("pooling output_vectors.shape =", output_vectors.shape, "and output_attention_mask.shape =", output_attention_mask.shape) output_pooled = mean_pool(output_vectors, output_attention_mask) # average with original vectors # TODO: Argparse for pooling strategy. # output_vectors = torch.cat((soft_prompt_pooled, output_pooled), dim=1) # (b, d) + (b, d) -> (b, 2d) # print("output_pooled.shape =", output_pooled.shape) output = self.output_projection(output_pooled) # (b, 2d) -> (b, d) # print("returning output.shape =", output.shape) if output_hidden_states: return { "hidden_states": output_vectors, "pooled": output, } else: return output class DatasetPrefixBiencoder(transformers.PreTrainedModel, ContextualModelMixin): def __init__( self, config, #: transformers.PreTrainedConfig, embedder: transformers.PreTrainedModel, ): super().__init__(config=config) self.embedder = embedder self.hidden_size = self.embedder.config.hidden_size self.contextual_init() def forward( self, input_ids: torch.Tensor, attention_mask: torch.Tensor, dataset_input_ids: torch.Tensor, dataset_attention_mask: torch.Tensor, output_hidden_states: bool = False, ) -> torch.Tensor: R = torch.randint(low=0, high=len(dataset_input_ids), size=(len(input_ids),), device=dataset_input_ids.device) dataset_input_ids = dataset_input_ids[R] input_ids = torch.cat((dataset_input_ids, input_ids), dim=1) dataset_attention_mask = torch.ones_like(dataset_attention_mask, device=dataset_attention_mask.device) input_attention_mask = torch.cat((dataset_attention_mask, attention_mask), dim=1) output_attention_mask = torch.cat( (torch.zeros_like(dataset_input_ids), attention_mask), dim=1 ) output = self.embedder( input_ids=input_ids, attention_mask=input_attention_mask, ) output_vectors = output.last_hidden_state output_pooled = mean_pool(output_vectors, output_attention_mask) output = self.output_projection(output_pooled) # (b, 2d) -> (b, d) if output_hidden_states: S_d = dataset_attention_mask.shape[1] output_vectors = output_vectors[:, S_d:, :] return { "hidden_states": output_vectors, "pooled": output, } else: return output class DatasetTransformer(transformers.PreTrainedModel): config_class = ContextualModelConfig embedder: transformers.PreTrainedModel dataset_backbone: transformers.PreTrainedModel def __init__( self, config, ): super().__init__(config=config) dataset_backbone, _ = load_embedder_and_tokenizer( vars(config).get("dataset_backbone", config.embedder) ) if config.limit_layers: print0(f"Limiting layers to {config.limit_layers}") limit_layers(dataset_backbone, config.limit_layers) biencoder_config = copy.deepcopy(config) biencoder_config.embedding_output_dim = None biencoder_config.limit_layers = vars(self.config).get("limit_layers_first_stage", None) self.first_stage_model = BiEncoder( config=biencoder_config, ) if vars(config).get("autoregressive_backbone", False): self.second_stage_model = DatasetConditionedAutoregressive( config=config, dataset_backbone=dataset_backbone, first_stage_hidden_size=self.first_stage_model.hidden_size, ) else: self.second_stage_model = DatasetConditionedBiencoder( config=config, dataset_backbone=dataset_backbone ) self.temp = config.logit_scale if config.disable_dropout: disable_dropout(self) transductive_tie_token_embeddings = vars(self.config).get("transductive_tie_token_embeddings", False) if transductive_tie_token_embeddings: self.second_stage_model.backbone.embeddings.word_embeddings.weight = ( self.first_stage_model.embedder.embeddings.word_embeddings.weight ) def forward( self, input_ids: torch.Tensor, attention_mask: torch.Tensor, dataset_input_ids: Optional[torch.Tensor], dataset_attention_mask: Optional[torch.Tensor], output_hidden_states: bool = False, ) -> torch.Tensor: """ input_ids (long torch.Tensor) – ids of input tokens attention_mask (bool torch.Tensor) """ dataset_embeddings = self.first_stage_model( input_ids=dataset_input_ids, attention_mask=dataset_attention_mask ) return self.second_stage_model( input_ids=input_ids, attention_mask=attention_mask, dataset_embeddings=dataset_embeddings, output_hidden_states=output_hidden_states, ) def get_model_class(name: str): if name in 'transductive': return DatasetTransformer elif name == 'biencoder': return BiEncoder elif name == "dataset_prefix_biencoder": return DatasetPrefixBiencoder else: raise ValueError(f'unknown model cls {name}')