jxm committed
Commit 246be81
1 Parent(s): da4e4d1

Update model.py

Files changed (1)
  1. model.py +412 -6
model.py CHANGED
@@ -1,14 +1,419 @@
1
- from typing import Dict, Optional, Union
2
 
3
  import copy
4
  import torch
5
  import torch.nn as nn
6
  import transformers
7
 
8
- from cde.lib.dist import print0
9
- from cde.lib.tensor import mean_pool, mean_pool_3d, mean_pool_weighted, last_token_pool
10
 
11
- from cde.lib import load_embedder_and_tokenizer, ContextualModelConfig
12
 
13
 
14
  def limit_layers(model: transformers.PreTrainedModel, n_layers: int) -> None:
@@ -25,6 +430,7 @@ def limit_layers(model: transformers.PreTrainedModel, n_layers: int) -> None:
25
  model.encoder.layer = model.encoder.layer[:n_layers]
26
  else:
27
  raise RuntimeError(f"unknown how to limit layers of model {type(model)}")
28
 
29
 
30
  def disable_dropout(model: torch.nn.Module):
@@ -413,7 +819,7 @@ class DatasetConditionedBiencoder(transformers.PreTrainedModel, ContextualModelM
413
  if hasattr(module, "rotary_emb_dim"):
414
  module.rotary_start_pos = rotary_start_pos
415
  rotary_disabled += 1
416
- print0(f"modified {rotary_disabled} rotary modules – set rotary_start_pos to {rotary_start_pos}")
417
 
418
  def forward(
419
  self,
@@ -580,7 +986,7 @@ class DatasetTransformer(transformers.PreTrainedModel):
580
  output_hidden_states: bool = False,
581
  ) -> torch.Tensor:
582
  """
583
- input_ids (long torch.Tensor) – ids of input tokens
584
  attention_mask (bool torch.Tensor)
585
  """
586
  dataset_embeddings = self.first_stage_model(
 
1
+ from typing import Callable, Dict, Optional, Union, Tuple
2
 
3
  import copy
4
+ import math
5
+ import multiprocessing
6
+ import json  # used by ContextualModelConfig.__init__ (json.dumps)
+ import os
7
+
8
  import torch
9
  import torch.nn as nn
10
  import transformers
11
 
12
 
13
+ class ContextualModelConfig(transformers.configuration_utils.PretrainedConfig):
14
+ """We create a dummy configuration class that will just set properties
15
+ based on whatever kwargs we pass in.
16
+
17
+ When this class is initialized (see experiments.py) we pass in the
18
+ union of all data, model, and training args, all of which should
19
+ get saved to the config json.
20
+ """
21
+
22
+ def __init__(self, **kwargs):
23
+ for key, value in kwargs.items():
24
+ try:
25
+ json.dumps(value)
26
+ setattr(self, key, value)
27
+ except TypeError:
28
+ # value was not JSON-serializable, skip
29
+ continue
30
+ super().__init__()
31
+
32
+
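# NOTE (illustrative sketch, not part of the original file): ContextualModelConfig keeps
# only the JSON-serializable kwargs it is given; everything else is silently dropped.
# The argument names below are invented for illustration.
_example_config = ContextualModelConfig(
    embedder="nomic-ai/nomic-bert-2048",
    max_seq_length=512,
    logging_fn=print,  # not JSON-serializable -> skipped
)
assert _example_config.max_seq_length == 512 and not hasattr(_example_config, "logging_fn")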
33
+ def load_embedder_and_tokenizer(name: str) -> Tuple[
34
+ transformers.PreTrainedModel,
35
+ transformers.PreTrainedTokenizer
36
+ ]:
37
+ if name.startswith("nomic") or (name == "bert-base-uncased"):
38
+ from cde.lib.nomic_bert import NomicBertModel
39
+ if name.endswith("--from-scratch"):
40
+ name = name.replace("--from-scratch", "")
41
+ config = transformers.AutoConfig.from_pretrained(name, trust_remote_code=True)
42
+ model = NomicBertModel._from_config(config)
43
+ else:
44
+ model = NomicBertModel.from_pretrained(
45
+ name, add_pooling_layer=False
46
+ )
47
+ tokenizer = transformers.AutoTokenizer.from_pretrained(name)
48
+ elif name in ["gtr-base", "gtr_base"]:
49
+ model = transformers.AutoModel.from_pretrained(
50
+ "sentence-transformers/gtr-t5-base"
51
+ ).encoder
52
+ tokenizer = transformers.AutoTokenizer.from_pretrained(
53
+ "sentence-transformers/gtr-t5-base"
54
+ )
55
+ elif name == "pile-t5-base-encoder":
56
+ model = transformers.AutoModel.from_pretrained(
57
+ "EleutherAI/pile-t5-base"
58
+ ).encoder
59
+ tokenizer = transformers.AutoTokenizer.from_pretrained(
60
+ "EleutherAI/pile-t5-base"
61
+ )
62
+ tokenizer.pad_token = tokenizer.eos_token
63
+ elif name == "pile-t5-base-decoder":
64
+ model = transformers.AutoModel.from_pretrained(
65
+ "EleutherAI/pile-t5-base"
66
+ ).decoder
67
+ tokenizer = transformers.AutoTokenizer.from_pretrained(
68
+ "EleutherAI/pile-t5-base"
69
+ )
70
+ tokenizer.pad_token = tokenizer.eos_token
71
+ elif name.startswith("gpt2") or name.startswith("meta-llama") or ("Llama" in name):
72
+ model = transformers.AutoModelForCausalLM.from_pretrained(
73
+ name,
74
+ # torch_dtype=torch.bfloat16,
75
+ attn_implementation="flash_attention_2",
76
+ low_cpu_mem_usage=True,
77
+ # device_map="auto",
78
+ )
79
+ model.padding_side = "right"
80
+ tokenizer = transformers.AutoTokenizer.from_pretrained(name)
81
+ tokenizer.pad_token = tokenizer.eos_token
82
+ tokenizer.add_eos_token = True
83
+ else:
84
+ model = transformers.AutoModel.from_pretrained(name, trust_remote_code=True)
85
+ tokenizer = transformers.AutoTokenizer.from_pretrained(name)
86
+
87
+ # if use_bettertransformer:
88
+ # from optimum.bettertransformer import BetterTransformer
89
+ # model = BetterTransformer.transform(model)
90
+ return model, tokenizer
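# NOTE (illustrative sketch, not part of the original file): any model id that does not
# match the special cases above falls through to the generic AutoModel branch, e.g.:
_embedder, _embedder_tokenizer = load_embedder_and_tokenizer("sentence-transformers/all-MiniLM-L6-v2")
_enc = _embedder_tokenizer(["a short test sentence"], return_tensors="pt")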
91
+ def get_world_size() -> int:
92
+ try:
93
+ return torch.distributed.get_world_size()
94
+ except (RuntimeError, ValueError):
95
+ return 1
96
+
97
+
98
+ def get_rank() -> int:
99
+ try:
100
+ return torch.distributed.get_rank()
101
+ except (RuntimeError, ValueError):
102
+ return 0
103
+
104
+ def gather(t: torch.Tensor) -> torch.Tensor:
105
+ # torch.distributed.nn.all_gather scales by world size since the reduce op is SUM
106
+ # https://github.com/pytorch/pytorch/issues/58005
107
+ # only should use torch.distributed.nn.all_gather if we implement a `local_loss`
108
+ # like: https://github.com/mlfoundations/open_clip/issues/616
109
+ world_size = get_world_size()
110
+ if world_size == 1:
111
+ return t
112
+
113
+ if t.ndim == 0:
114
+ t = t.unsqueeze(0)
115
+
116
+ gathered = [torch.empty_like(t) for _ in range(world_size)]
117
+ torch.distributed.all_gather(gathered, t)
118
+ gathered[get_rank()] = t
119
+ return torch.cat(gathered, dim=0)
120
+
121
+
122
+ def gather_sum(t: torch.Tensor) -> torch.Tensor:
123
+ # torch.distributed.nn.all_gather scales by world size since the reduce op is SUM
124
+ # https://github.com/pytorch/pytorch/issues/58005
125
+ # only should use torch.distributed.nn.all_gather if we implement a `local_loss`
126
+ # like: https://github.com/mlfoundations/open_clip/issues/616
127
+ world_size = get_world_size()
128
+ if world_size == 1:
129
+ return t
130
+
131
+ if t.ndim == 0:
132
+ t = t.unsqueeze(0)
133
+
134
+ gathered = [torch.empty_like(t) for _ in range(world_size)]
135
+ torch.distributed.all_gather(gathered, t)
136
+ gathered = torch.stack(gathered, dim=0)
137
+ return gathered.sum(dim=0) # Sum across workers
138
+
139
+
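# NOTE (illustrative sketch, not part of the original file): without torch.distributed
# initialized, get_world_size() is 1 and both helpers are no-ops; under DDP, gather()
# concatenates the per-rank tensors along dim 0 (keeping the local rank's autograd path),
# while gather_sum() reduces them with a sum.
_t = torch.ones(2)
assert torch.equal(gather(_t), _t) and torch.equal(gather_sum(_t), _t)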
140
+ def get_num_proc() -> int:
141
+ world_size: int = get_world_size()
142
+ try:
143
+ # os.sched_getaffinity respects schedulers, unlike cpu_count(), but it's only available
144
+ # on some Unix platforms, so we support both!
145
+ return len(os.sched_getaffinity(0)) // world_size # type: ignore[attr-defined]
146
+ except AttributeError:
147
+ return multiprocessing.cpu_count() // world_size
148
+
149
+
150
+ def torch_main_worker_finish_first(func: Callable):
151
+ def wrapper(*args, **kwargs):
152
+ # Get local rank (need to support non-DDP).
153
+ try:
154
+ local_rank = torch.distributed.get_rank()
155
+ ddp_enabled = True
156
+ except (RuntimeError, ValueError):
157
+ local_rank = -1
158
+ ddp_enabled = False
159
+ is_main_worker = local_rank <= 0
160
+ # Run on main worker first.
161
+ if is_main_worker:
162
+ result = func(*args, **kwargs)
163
+ # Then everyone waits.
164
+ if ddp_enabled:
165
+ torch.distributed.barrier()
166
+ # Run on other workers now.
167
+ if not is_main_worker:
168
+ result = func(*args, **kwargs)
169
+ # Now everyone waits again.
170
+ if ddp_enabled:
171
+ torch.distributed.barrier()
172
+ return result
173
+
174
+ return wrapper
175
+
176
+
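# NOTE (illustrative sketch, not part of the original file): the decorator above makes
# rank 0 run a function first so the other ranks can reuse its side effects (e.g. a cache
# written to disk). The function below is hypothetical.
@torch_main_worker_finish_first
def _prepare_cache(path: str) -> str:
    # e.g. tokenize a dataset and save it to `path` exactly once per node
    return path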
177
+ def print0(*args, **kwargs) -> None:
178
+ if get_rank() == 0:
179
+ print(*args, **kwargs)
180
+
181
+
182
+ def verify_ddp_weights_equal(model: torch.nn.Module, atol: float = 1e-5) -> None:
183
+ if hasattr(model, "module"):
184
+ model = model.module
185
+
186
+ world_size = get_world_size()
187
+
188
+ if world_size > 8:
189
+ print0(f"[verify_ddp_weights_equal] Skipping with world_size={world_size} ⚠️")
190
+ return
191
+
192
+ for name, param in model.named_parameters():
193
+ if param is None: continue
194
+ if param.grad is None:
195
+ print0(f"[verify_ddp_weights_equal] Skipping param [{name}] with no grad")
196
+ continue
197
+ gathered_param = gather(param).reshape((world_size, -1))
198
+ absolute_diffs = (gathered_param[None, 0, :] - gathered_param).abs()
199
+ rank_params_eq = (absolute_diffs < atol).all()
200
+ assert rank_params_eq, f"❌ param [{name}] not equal - got max_absolute_diff={absolute_diffs.max()}"
201
+ ###################################################################################################################
202
+ gathered_param_grad = gather(param.grad).reshape((world_size, -1))
203
+ absolute_grad_diffs = (gathered_param_grad[None, 0, :] - gathered_param_grad).abs()
204
+ rank_grad_params_eq = (absolute_grad_diffs < atol).all()
205
+ assert rank_grad_params_eq, f"❌ param [{name}] grad not equal - got max_absolute_diff={absolute_grad_diffs.max()}"
206
+ ###################################################################################################################
207
+
208
+
209
+ print0("[verify_ddp_weights_equal] Verified DDP parameter correctness ✅")
210
+
211
+
212
+
213
+ def mean_pool_3d(
214
+ hidden_states: torch.Tensor, attention_mask: torch.Tensor
215
+ ) -> torch.Tensor:
216
+ B, T, S, D = hidden_states.shape
217
+ unmasked_outputs = hidden_states * attention_mask[..., None]
218
+ pooled_outputs = unmasked_outputs.sum(dim=2) / (attention_mask.sum(dim=2)[..., None] + 1e-9)
219
+
220
+ # fix for gradient flow: fill empty rows with the mean of the rest of the sequence
221
+ sequence_means = (
222
+ hidden_states.reshape((B, S * T, D))
223
+ .mean(dim=1, keepdim=True)
224
+ .expand(-1, T, -1)
225
+ )
226
+ pooled_outputs = pooled_outputs.where(
227
+ (attention_mask.sum(dim=2)[..., None] > 0),
228
+ sequence_means
229
+ )
230
+ assert pooled_outputs.shape == (B, T, D)
231
+
232
+ return pooled_outputs
233
+
234
+ def mean_pool(
235
+ hidden_states: torch.Tensor, attention_mask: torch.Tensor
236
+ ) -> torch.Tensor:
237
+ B, _S, D = hidden_states.shape
238
+ unmasked_outputs = hidden_states * attention_mask[..., None]
239
+ pooled_outputs = unmasked_outputs.sum(dim=1) / (attention_mask.sum(dim=1)[:, None] + 1e-20)
240
+
241
+ assert pooled_outputs.shape == (B, D)
242
+ return pooled_outputs
243
+
244
+
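# NOTE (illustrative sketch, not part of the original file): mean_pool averages the hidden
# states at non-masked positions, mapping (B, S, D) + (B, S) -> (B, D).
_hidden = torch.randn(2, 5, 8)
_mask = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]])
assert mean_pool(_hidden, _mask).shape == (2, 8)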
245
+ def mean_pool_weighted(
246
+ hidden_states: torch.Tensor, attention_mask: torch.Tensor
247
+ ) -> torch.Tensor:
248
+ B, _S, D = hidden_states.shape
249
+ attention_mask *= attention_mask.cumsum(dim=1) # [0,1,1,1,0,0] -> [0,1,2,3,0,0]
250
+ s = torch.sum(hidden_states * attention_mask.unsqueeze(-1).float(), dim=1)
251
+ d = attention_mask.sum(dim=1, keepdim=True).float()
252
+ return s / d
253
+
254
+
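# NOTE (not part of the original file): mean_pool_weighted modifies `attention_mask`
# in place via `*=`, turning the 0/1 mask into position weights; pass a clone
# (e.g. mean_pool_weighted(hidden, mask.clone())) if the caller still needs the mask.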
255
+ def slice_sparse_tensor_rows(t: torch.sparse.Tensor, min_row: int, max_row: int) -> torch.sparse.Tensor:
256
+ assert min_row < max_row, f"can't slice from row {min_row} to {max_row}"
257
+ t = t.coalesce()
258
+ row_idxs = t.indices()[0]
259
+ index_mask = (min_row <= row_idxs) & (row_idxs < max_row)
260
+
261
+ num_rows = (max_row - min_row)
262
+ num_cols = t.shape[1]
263
+
264
+ idxs = t.indices()[:, index_mask]
265
+ vals = t.values()[index_mask]
266
+ return torch.sparse_coo_tensor(idxs, vals, size=(num_rows, num_cols)).coalesce()
267
+
268
+
269
+ def slice_tensor_rows(t: torch.Tensor, min_row: int, max_row: int) -> torch.Tensor:
270
+ if t.is_sparse:
271
+ return slice_sparse_tensor_rows(t=t, min_row=min_row, max_row=max_row)
272
+ else:
273
+ return t[min_row:max_row]
274
+
275
+
276
+ @torch.no_grad
277
+ def maxsim(
278
+ X: torch.Tensor, y: torch.Tensor,
279
+ maximize: bool, chunk_size: int = 8_000,
280
+ debug_mem_usage: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
281
+ device = X.device
282
+ n_samples = X.shape[0]
283
+
284
+ max_sim_v = torch.zeros(n_samples, device=device, dtype=X.dtype)
285
+ max_sim_i = torch.zeros(n_samples, device=device, dtype=torch.int64)
286
+
287
+ # TODO: Implement faster max (without going to dense tensors).
288
+ # TODO: Use multiple GPUs.
289
+ rank = get_rank()
290
+ world_size = get_world_size()
291
+
292
+ worker_worklist_size = int(math.ceil(n_samples / world_size))
293
+ splits_start_idx = worker_worklist_size * rank
294
+ splits_end_idx = worker_worklist_size * (rank + 1)
295
+
296
+ for i in range(splits_start_idx, splits_end_idx, chunk_size):
297
+ start, end = i, min(i + chunk_size, n_samples)
298
+ sub_x = slice_tensor_rows(X, start, end)
299
+ if debug_mem_usage: print(f"[maxsim] step {i} cuda mem free/total = {torch.cuda.mem_get_info()}")
300
+ if debug_mem_usage: print("[maxsim] sub_x.shape:", sub_x.shape, "//", "y.shape:", y.shape)
301
+ sub_sim = sub_x @ y # TODO – Implement sparse max here to save mem!
302
+ sub_sim = sub_sim
303
+ if maximize:
304
+ sub_max_sim_v, sub_max_sim_i = sub_sim.to_dense().max(dim=-1)
305
+ else:
306
+ sub_max_sim_v, sub_max_sim_i = sub_sim.to_dense().min(dim=-1)
307
+ del sub_sim
308
+ del sub_x
309
+ torch.cuda.empty_cache() # needs to happen after maxsim for some reason.
310
+ max_sim_v[start: end] = sub_max_sim_v
311
+ max_sim_i[start: end] = sub_max_sim_i
312
+
313
+ # gather
314
+ max_sim_v = gather_sum(max_sim_v)
315
+ max_sim_i = gather_sum(max_sim_i)
316
+ k = y.shape[1]
317
+
318
+ assert max_sim_v.shape == (n_samples,)
319
+ assert max_sim_i.shape == (n_samples,)
320
+ assert max_sim_i.min() >= 0
321
+ assert max_sim_i.max() <= k
322
+
323
+ return max_sim_v, max_sim_i
324
+
325
+
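# NOTE (illustrative sketch, not part of the original file): with a single worker,
# maxsim() is a chunked (arg)max over X @ y: for every row of X it returns the best
# score and the index of the best column of y.
_X = torch.randn(10, 4)
_y = torch.randn(4, 6)
_vals, _idxs = maxsim(_X, _y, maximize=True, chunk_size=4)
assert _vals.shape == (10,) and _idxs.shape == (10,)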
326
+ def forward_batched(
327
+ model: torch.nn.Module,
328
+ input_ids: torch.Tensor,
329
+ attention_mask: torch.Tensor,
330
+ batch_size: int,
331
+ dataset_input_ids: Optional[torch.Tensor] = None,
332
+ dataset_attention_mask: Optional[torch.Tensor] = None,
333
+ **second_stage_model_kwargs,
334
+ ) -> torch.Tensor:
335
+ if hasattr(model, "module"):
336
+ model = model.module
337
+
338
+ if hasattr(model, "first_stage_model"):
339
+ # Support pooling over 3D dataset_input_ids inputs.
340
+ if len(dataset_input_ids.shape) == 2:
341
+ dataset_input_ids = dataset_input_ids[None]
342
+ dataset_attention_mask = dataset_attention_mask[None]
343
+
344
+ dataset_embeddings = []
345
+ for j in range(len(dataset_input_ids)):
346
+ i = 0
347
+ dataset_embeddings_batch = []
348
+ while i < dataset_input_ids.shape[1]:
349
+ dataset_embeddings_batch.append(
350
+ model.first_stage_model(
351
+ input_ids=dataset_input_ids[j][i:i+batch_size],
352
+ attention_mask=dataset_attention_mask[j][i:i+batch_size],
353
+ )
354
+ )
355
+ i += batch_size
356
+ dataset_embeddings.append(
357
+ torch.cat(dataset_embeddings_batch, dim=0)
358
+ )
359
+
360
+ # Automatically pool over 3D dataset_input_ids.
361
+ dataset_embeddings = torch.stack(dataset_embeddings, dim=0).mean(dim=0)
362
+
363
+ j = 0
364
+ outputs = []
365
+ while j < len(input_ids):
366
+ outputs.append(
367
+ model.second_stage_model(
368
+ input_ids=input_ids[j:j+batch_size],
369
+ attention_mask=attention_mask[j:j+batch_size],
370
+ dataset_embeddings=dataset_embeddings,
371
+ **second_stage_model_kwargs,
372
+ )
373
+ )
374
+ j += batch_size
375
+ return torch.cat(outputs, dim=0)
376
+
377
+ else:
378
+ i = 0
379
+ outputs = []
380
+ while i < len(input_ids):
381
+ # breakpoint()
382
+ outputs.append(
383
+ model(
384
+ input_ids=input_ids[i:i+batch_size],
385
+ attention_mask=attention_mask[i:i+batch_size],
386
+ **second_stage_model_kwargs,
387
+ )
388
+ )
389
+ i += batch_size
390
+ return torch.cat(outputs, dim=0)
391
+
392
+
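# NOTE (illustrative sketch, not part of the original file): for a two-stage model,
# forward_batched() first embeds the "dataset"/context documents with first_stage_model,
# then embeds the actual inputs with second_stage_model in chunks of `batch_size`.
# `doc_inputs` and `dataset_inputs` below stand in for tokenizer outputs.
#
# embeddings = forward_batched(
#     model=model,
#     input_ids=doc_inputs["input_ids"],
#     attention_mask=doc_inputs["attention_mask"],
#     dataset_input_ids=dataset_inputs["input_ids"],
#     dataset_attention_mask=dataset_inputs["attention_mask"],
#     batch_size=32,
# )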
393
+ def last_token_pool(hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
394
+ # https://github.com/ContextualAI/gritlm/blob/main/gritlm/gritlm.py#L190
395
+ b, n, d = hidden_state.size()
396
+ # Get the last `1` in the attention mask of each item
397
+ # Often it is just `gather_indices = torch.argmin(attention_mask, 1, keepdim=False) - 1`
398
+ # except when 1) There's all 1's 2) There's 0's before the 1's
399
+ reversed_mask = torch.flip(attention_mask, dims=(1,))
400
+ argmax_reverse = torch.argmax(reversed_mask, dim=1, keepdim=False)
401
+ gather_indices = attention_mask.size(1) - argmax_reverse - 1
402
+ # If there are empty sequences, where the index would become -1 it will crash so set them to 0
403
+ gather_indices = torch.clamp(gather_indices, min=0)
404
+ # Turn indices from shape [b] -> [b, 1, d]
405
+ gather_indices = gather_indices.unsqueeze(-1).repeat(1, d)
406
+ gather_indices = gather_indices.unsqueeze(1)
407
+ assert gather_indices.shape == (b, 1, d)
408
+ # Gather along the seq len: [b, n, d] -> [b, d]
409
+ # Actually no need for the attention mask as we gather the last token where attn_mask=1 but
410
+ # as some indices (which shouldn't be attended to) may be 0 due to clamp, use mask to ignore them again
411
+ input_mask_expanded = attention_mask.unsqueeze(-1).expand((b, n, d)).float()
412
+ return torch.gather(hidden_state * input_mask_expanded, 1, gather_indices).squeeze(dim=1)
413
+
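# NOTE (illustrative sketch, not part of the original file): last_token_pool picks the
# hidden state at each sequence's final attended position (positions 2 and 4 below).
_h = torch.arange(2 * 5 * 3, dtype=torch.float32).reshape(2, 5, 3)
_m = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]])
assert torch.equal(last_token_pool(_h, _m), torch.stack([_h[0, 2], _h[1, 4]]))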
414
+ def print0(*args, **kwargs) -> None:
415
+ if get_rank() == 0:
416
+ print(*args, **kwargs)
417
 
418
 
419
  def limit_layers(model: transformers.PreTrainedModel, n_layers: int) -> None:
 
430
  model.encoder.layer = model.encoder.layer[:n_layers]
431
  else:
432
  raise RuntimeError(f"unknown how to limit layers of model {type(model)}")
433
+
434
 
435
 
436
  def disable_dropout(model: torch.nn.Module):
 
819
  if hasattr(module, "rotary_emb_dim"):
820
  module.rotary_start_pos = rotary_start_pos
821
  rotary_disabled += 1
822
+ print0(f"modified {rotary_disabled} rotary modules – set rotary_start_pos to {rotary_start_pos}")
823
 
824
  def forward(
825
  self,
 
986
  output_hidden_states: bool = False,
987
  ) -> torch.Tensor:
988
  """
989
+ input_ids (long torch.Tensor) – ids of input tokens
990
  attention_mask (bool torch.Tensor)
991
  """
992
  dataset_embeddings = self.first_stage_model(