|
import torch |
|
import numpy as np |
|
import torch.distributed as dist |
|
from transformers.utils import logging |
|
from typing import List, Tuple, Optional |
|
from .modeling_retrieval import BM25Retriever |
|
|
|
logger = logging.get_logger(__name__) |
|
|
|
|
|
class Memory(torch.nn.Module):
    """Sliding-window KV-cache manager for activation-beacon models.

    Feeds a long sequence to the model one window at a time (:meth:`step`),
    and accumulates two kinds of key/value activations across steps
    (:meth:`update_memory`):

    - **beacon activations**: activations of special beacon tokens that
      condense a full window's stride at some condensing ratio;
    - **raw activations**: uncondensed activations of tokens that have not
      yet been condensed (e.g. the tail of an unfinished window).

    Both are concatenated into ``past_key_values`` so the next window can
    attend to everything seen so far.
    """

    def __init__(self, model_config, beacon_window:int=1024, beacon_stride:List[int]=[512], beacon_attn:str="step-expansion", beacon_attend_previous:bool=True, beacon_ratio:List[int]=[8], beacon_stride_mix:str="step-random", beacon_ratio_mix:str="step-random", beacon_param:List[str]=["q", "k", "v", "o"], k_seq_dim:int=2, v_seq_dim:int=2, retrieval_method:str=None, retrieval_topk:int=2) -> None:
        """
        Args:
            model_config: model config; only ``vocab_size``,
                ``num_hidden_layers`` and ``max_position_embeddings`` are read.
            beacon_window: window size L. Beacons are appended only when a
                window of this many tokens is full.
            beacon_stride: candidate strides S (how far the window advances
                after each full window).
            beacon_attn: attention scheme of beacon tokens within the window.
            beacon_attend_previous: whether beacons attend to previous beacons
                (only logged here).
            beacon_ratio: candidate condensing ratios; a stride of S tokens is
                condensed into ``S // ratio`` beacons (0 disables condensing).
            beacon_stride_mix: how to pick a stride among the candidates.
            beacon_ratio_mix: how to pick a condensing ratio among the
                candidates.
            beacon_param: which projection matrices carry beacon parameters
                (only logged here).
            k_seq_dim: sequence dimension of cached keys.
            v_seq_dim: sequence dimension of cached values.
            retrieval_method: optional ``"bm25"``; retrieved windows then get
                the smaller of the two condensing ratios.
            retrieval_topk: number of windows to retrieve (must be >= 2).
        """
        super().__init__()

        for stride in beacon_stride:
            assert beacon_window >= stride, f"Make sure the beacon_window {beacon_window} >= beacon_stride {stride}!"
        assert beacon_attn in ["segmentation", "step-expansion", "full-coverage"], f"beacon_attn {beacon_attn} not implemented!"
        assert beacon_stride_mix in ["instance-random", "step-random", "mix-random"], f"beacon_stride_mix {beacon_stride_mix} not implemented!"
        assert beacon_ratio_mix in ["instance-random", "step-random", "mix-random", "sequence"] or "adapt-" in beacon_ratio_mix, f"beacon_ratio_mix {beacon_ratio_mix} not implemented!"

        if retrieval_method == "bm25":
            assert len(beacon_stride) == 1, f"Currently retrieval do not support dynamic strides."
            assert retrieval_topk >= 2, f"Make sure retrieval_topk >= 2. Found {retrieval_topk}."
            # FIX: the message previously interpolated self.beacon_ratio before
            # the attribute was assigned, turning a failed assertion into an
            # AttributeError; use the local argument instead.
            assert len(beacon_ratio) == 2, f"Make sure there are two beacon ratios specified, one for retrieved windows and the other for non-retrieved windows. Found {beacon_ratio}"

        info = f"applying activation beacon on {beacon_param}, with window size {beacon_window}, stride {beacon_stride} (mixed by {beacon_stride_mix}), {beacon_attn} attention ({'attending to previous beacons' if beacon_attend_previous else 'not attending to previous beacons'}), condensing ratio {beacon_ratio} (mixed by {beacon_ratio_mix}), {retrieval_method+' retrieval'+' top-'+str(retrieval_topk) if retrieval_method is not None else 'no retrieval'}, ..."
        logger.info(info)

        self.beacon_window = beacon_window
        self.beacon_stride = beacon_stride
        self.beacon_attn = beacon_attn
        self.beacon_ratio = beacon_ratio
        self.beacon_stride_mix = beacon_stride_mix
        self.beacon_ratio_mix = beacon_ratio_mix
        # the maximum number of beacon tokens any (window, ratio) pair can
        # produce; at least 1 so beacon_tokens is never empty
        max_beacon_size = max([beacon_window // x for x in beacon_ratio if x > 0] + [1])
        # beacon tokens use the id right past the original vocabulary
        self.beacon_tokens = torch.zeros(max_beacon_size, dtype=torch.long) + model_config.vocab_size

        self.k_seq_dim = k_seq_dim
        self.v_seq_dim = v_seq_dim
        self.num_layers = model_config.num_hidden_layers
        self.max_position_embeddings = model_config.max_position_embeddings

        self.retrieval_method = retrieval_method
        self.retrieval_topk = retrieval_topk

        # fixed seed so stride/ratio sampling is reproducible
        self.rng = np.random.default_rng(42)
        self.reset()

    @property
    def finish(self):
        """True once the window has consumed the entire current sequence."""
        return self.end_idx == self.sequence_length

    def get_memory_size(self):
        """Return ``(beacon_memory_size, raw_memory_size, total)`` measured
        along the key sequence dimension of the first layer's cache."""
        beacon_memory_size = 0
        raw_memory_size = 0
        if self.beacon_activations[0][0] is not None:
            beacon_memory_size += self.beacon_activations[0][0].shape[self.k_seq_dim]
        if self.raw_activations[0][0] is not None:
            raw_memory_size += self.raw_activations[0][0].shape[self.k_seq_dim]
        memory_size = beacon_memory_size + raw_memory_size
        return beacon_memory_size, raw_memory_size, memory_size

    def reset(self):
        """Re-initialize all per-sequence state (cursors, caches, losses)."""
        # length of the current input sequence
        self.sequence_length = 0
        # total length of all sequences seen since the last reset
        self.total_sequence_length = 0
        # cursor to the start of the current window
        self.start_idx = 0
        # cursor to the end of the current window
        self.end_idx = 0
        # beacon sizes chosen at each step (-1 = verbatim cache, 0 = no beacons)
        self._beacon_sizes = []
        # number of step() calls so far
        self.step_idx = 0

        # FIX: always initialize the per-instance stride/ratio caches. They
        # were previously skipped when beacon_ratio_mix == "step-random", but
        # self._stride is read by set_stride() whenever beacon_stride_mix ==
        # "instance-random" regardless of the ratio mix, which raised
        # AttributeError. Initializing unconditionally is harmless for mixes
        # that never read them.
        self._stride = None
        self._ratio = None

        # accumulated (loss * token_count) and token counts for perplexity
        self.batch_loss = None
        self.valid_token_num = None

        # per-layer (key, value) caches
        self.raw_activations = [(None, None) for _ in range(self.num_layers)]
        self.beacon_activations = [(None, None) for _ in range(self.num_layers)]

        if self.retrieval_method == "bm25":
            self.retriever = BM25Retriever()
            self.beacon_ratio_mix = "retrieval"

        # keep the sampling RNG identical across data-parallel ranks so every
        # rank picks the same strides/ratios
        if self.training and dist.is_initialized():
            rng_state = self.rng.__getstate__()
            if dist.get_rank() == 0:
                obj = [rng_state]
            else:
                obj = [None]
            dist.broadcast_object_list(obj, src=0)
            self.rng.__setstate__(obj[0])

    def prepare(self, input_ids, attention_mask, labels):
        """
        Prepare inputs for the model.

        Stores the new sequence, rebases the window cursors onto it, shifts
        the labels for next-token prediction, and (for BM25 retrieval) builds
        the window index before the first step.
        """
        assert input_ids.shape[0] == 1, f"Make sure batch_size is 1!"

        # rebase cursors relative to the new input (carries over any
        # remainder of a previously unfinished window)
        self.start_idx -= self.sequence_length
        self.end_idx -= self.sequence_length

        self.sequence_length = input_ids.shape[1]
        self.total_sequence_length += input_ids.shape[1]

        if labels is not None:
            # shift labels left by one so position i predicts token i+1;
            # the final position gets the ignore index -100
            labels = torch.cat([labels[:, 1:], labels.new_zeros((labels.shape[0], 1)) - 100], dim=-1)

        self.input_ids = input_ids
        self.attention_mask = attention_mask
        self.labels = labels

        # before the first step, index all windows with BM25 and retrieve the
        # ones most relevant to the last 32 tokens (used as the query)
        if self.retrieval_method == "bm25" and self.step_idx == 0:
            index = BM25Retriever()
            window = self.beacon_window
            stride = self.beacon_stride[0]

            input_ids = input_ids[0].tolist()
            corpus = [input_ids[:window]]
            for j in range(window, len(input_ids), stride):
                corpus.append(input_ids[j: j + stride])

            query = input_ids[-32:]
            index.index(corpus)
            topk_scores, topk_indices = index.search(query, hits=self.retrieval_topk)
            # drop the -1 padding returned for missing hits
            topk_indices = set([x for x in topk_indices[0] if x > -1])

            self._topk_indices = topk_indices

    def set_stride(self):
        """Choose a stride from self.beacon_stride"""
        beacon_stride = self.beacon_stride

        if len(beacon_stride) == 1:
            return beacon_stride[0]

        if self.beacon_stride_mix == "mix-random":
            # randomly fall back to one of the concrete mixing strategies
            stride_mix = self.rng.choice(["instance-random", "step-random"]).tolist()
        else:
            stride_mix = self.beacon_stride_mix

        if stride_mix == "instance-random":
            # one stride per sequence, sampled on first use and cached
            # FIX: this previously read self._beacon_stride, an attribute that
            # is never defined anywhere (reset() defines self._stride), so
            # this path always raised AttributeError.
            if self._stride is None:
                stride = self.rng.choice(beacon_stride).tolist()
                self._stride = stride
            else:
                stride = self._stride

        elif stride_mix == "step-random":
            # a fresh stride at every step
            stride = self.rng.choice(beacon_stride).tolist()

        else:
            raise NotImplementedError

        return stride

    def set_condensing_ratio(self, beacon_stride, start_idx, end_idx):
        """Choose a condensing ratio from self.beacon_ratio"""
        def filter_ratio(ratios, stride):
            # keep only ratios that are compatible with the given stride
            valid_ratios = []
            for ratio in ratios:
                # cannot condense below one beacon token
                if stride < ratio:
                    continue
                # the stride must split evenly into beacons unless attention
                # is full-coverage
                if self.beacon_attn != "full-coverage" and ratio > 0 and (stride % ratio) != 0:
                    continue
                # during training, a no-condensing step is useless unless
                # there are beacons before it or full windows after it
                if ratio == 0 and self.training:
                    previous_beacons = [b for b in self._beacon_sizes if b != -1]
                    following_beacons = (start_idx + stride + self.beacon_window) <= self.sequence_length
                    if len(previous_beacons) == 0 and not following_beacons:
                        continue
                valid_ratios.append(ratio)
            assert len(valid_ratios), f"Cannot find valid condensing ratio (among {ratios}) for stride {stride}!"
            return valid_ratios

        def get_max_length(ratios):
            # maximum raw sequence length each ratio can compress into the
            # model's positional range
            max_lengths = []
            for condensing_ratio in ratios:
                if condensing_ratio > 0:
                    max_lengths.append((self.max_position_embeddings - self.beacon_window) * condensing_ratio + self.beacon_window)
                else:
                    max_lengths.append(self.max_position_embeddings)
            return max_lengths

        if len(self.beacon_ratio) == 1:
            return self.beacon_ratio[0]

        beacon_ratio = filter_ratio(self.beacon_ratio, beacon_stride)

        if self.beacon_ratio_mix == "mix-random":
            ratio_mix = self.rng.choice(["instance-random", "step-random"]).tolist()
        else:
            ratio_mix = self.beacon_ratio_mix

        if ratio_mix == "instance-random":
            # one ratio per sequence, sampled on first use and cached
            if self._ratio is None:
                beacon_ratio = self.rng.choice(beacon_ratio).tolist()
                self._ratio = beacon_ratio
            else:
                beacon_ratio = self._ratio

        elif ratio_mix == "step-random":
            # a fresh ratio at every step
            beacon_ratio = self.rng.choice(beacon_ratio).tolist()

        elif ratio_mix == "sequence":
            # walk through the configured ratios step by step, staying on the
            # last one afterwards
            idx = min(self.step_idx, len(beacon_ratio) - 1)
            beacon_ratio = beacon_ratio[idx]

        elif ratio_mix == "retrieval":
            # retrieved windows are condensed less aggressively
            if self.step_idx in self._topk_indices:
                beacon_ratio = min(self.beacon_ratio)
            else:
                beacon_ratio = max(self.beacon_ratio)

        elif "adapt" in ratio_mix:
            # "adapt-<future_length>": pick the smallest ratio whose maximum
            # supported length still covers the expected total length
            if self._ratio is None:
                future_length = int(ratio_mix.split("-")[1])
                sequence_length = self.total_sequence_length + future_length
                max_lengths = get_max_length(beacon_ratio)

                valid_max_lengths_and_indices = [x for x in enumerate(max_lengths) if x[1] >= sequence_length]
                if len(valid_max_lengths_and_indices):
                    minimum_length_index = min(valid_max_lengths_and_indices, key=lambda x: x[1])[0]
                    beacon_ratio = beacon_ratio[minimum_length_index]
                else:
                    # nothing fits; condense as much as possible
                    beacon_ratio = max(beacon_ratio)

                self._ratio = beacon_ratio
            else:
                beacon_ratio = self._ratio

        return beacon_ratio

    def step(self):
        """
        Yield one window with the following logic:

        The window size is L, the stride is S.
        The window moves over S tokens at a time. The raw activations passed by the window are condensed according to a condensing_ratio.
        The beacons are added if and only if the raw activations fulfill the window.
        In the future, we may switch window size to decrease cache size of raw activations.

        Returns:
            ``(input_ids, attention_mask, past_key_values, labels)`` for this
            window; each entry of ``past_key_values`` is
            ``(key, value, beacon_size, raw_size_to_cache, window_size)``.
        """
        start_idx = self.start_idx
        end_idx = start_idx + self.beacon_window

        # the window is full only when enough tokens remain to fill it
        is_full_window = True
        if end_idx > self.sequence_length:
            end_idx = self.sequence_length
            is_full_window = False

        window_size = end_idx - start_idx

        if is_full_window:
            beacon_stride = self.set_stride()
            condensing_ratio = self.set_condensing_ratio(beacon_stride, start_idx=start_idx, end_idx=end_idx)

            if condensing_ratio > 0:
                # the stride is condensed into beacon_size beacon tokens
                beacon_size = beacon_stride // condensing_ratio
            else:
                # -1 marks "cache the stride's raw activations verbatim"
                beacon_size = -1

            # advance the window by one stride; tokens past the stride stay
            # in the raw cache
            next_start_idx = start_idx + beacon_stride
            raw_size_to_cache = end_idx - next_start_idx
            self.remaining_size = 0
        else:
            # the window cannot advance; cache the whole partial window raw
            next_start_idx = start_idx
            raw_size_to_cache = window_size
            self.remaining_size = window_size
            beacon_size = 0

        # only the tokens that newly enter the window are fed to the model
        input_ids = self.input_ids[:, self.end_idx: end_idx]
        batch_size = input_ids.shape[0]
        if self.attention_mask is not None:
            attention_mask = self.attention_mask[:, self.end_idx: end_idx]
        else:
            attention_mask = torch.ones_like(input_ids)
        if self.labels is not None:
            labels = self.labels[:, self.end_idx: end_idx]
        else:
            labels = None

        # prepend mask entries so new tokens attend to all cached activations
        _, _, memory_size = self.get_memory_size()
        if memory_size > 0:
            attention_mask = torch.cat([attention_mask.new_ones(batch_size, memory_size), attention_mask], dim=1)

        # append beacon tokens at the end of a full window
        if is_full_window and beacon_size > 0:
            input_ids = torch.cat([input_ids, self.beacon_tokens[:beacon_size].expand(batch_size, -1).to(input_ids.device, dtype=input_ids.dtype)], dim=1)
            attention_mask = torch.cat([attention_mask, attention_mask.new_ones(batch_size, beacon_size)], dim=1)
            if labels is not None:
                # beacon tokens never contribute to the loss
                labels = torch.cat([labels, labels.new_zeros(batch_size, beacon_size) - 100], dim=1)

        # assemble per-layer past_key_values from beacon + raw caches
        past_key_values = []
        for (beacon_key, beacon_value), (raw_key, raw_value) in zip(self.beacon_activations, self.raw_activations):
            key = cat_tensor([beacon_key, raw_key], dim=self.k_seq_dim)
            value = cat_tensor([beacon_value, raw_value], dim=self.v_seq_dim)
            layer_past_key_values = (key, value, beacon_size, raw_size_to_cache, window_size)
            past_key_values.append(layer_past_key_values)

        self._beacon_sizes.append(beacon_size)
        # move the cursors
        self.start_idx = next_start_idx
        self.end_idx = end_idx
        self.step_idx += 1

        return input_ids, attention_mask, past_key_values, labels

    def update_memory(self, past_key_values):
        """
        Accumulate beacon activations and raw activations.

        Args:
            past_key_values: per-layer tuples of
                ``(key, value, beacon_size, raw_size_to_cache, window_size)``
                returned by the model; ``key``/``value`` cover only the newly
                computed tokens of this step.
        """
        for layer_idx, (key, value, beacon_size, raw_size_to_cache, window_size) in enumerate(past_key_values):

            previous_beacon_key, previous_beacon_value = self.beacon_activations[layer_idx]
            previous_raw_key, previous_raw_value = self.raw_activations[layer_idx]

            if beacon_size == 0:
                # partial window: no beacons were produced; append everything
                # to the raw cache unchanged
                beacon_key = previous_beacon_key
                beacon_value = previous_beacon_value

                assert raw_size_to_cache == window_size
                raw_key = cat_tensor([
                    previous_raw_key,
                    key
                ], dim=self.k_seq_dim)
                raw_value = cat_tensor([
                    previous_raw_value,
                    value
                ], dim=self.v_seq_dim)

            elif beacon_size == -1:
                # condensing ratio 0: the stride's raw activations are moved
                # into the beacon cache verbatim
                if raw_size_to_cache > 0:
                    # split the concatenated raw activations: everything
                    # before the tail of raw_size_to_cache becomes "beacon"
                    concat_key = cat_tensor([
                        previous_raw_key,
                        key
                    ], dim=self.k_seq_dim)
                    concat_value = cat_tensor([
                        previous_raw_value,
                        value
                    ], dim=self.v_seq_dim)

                    beacon_key = cat_tensor([
                        previous_beacon_key,
                        slice_tensor(concat_key, end=-raw_size_to_cache, dim=self.k_seq_dim)
                    ], dim=self.k_seq_dim)
                    beacon_value = cat_tensor([
                        previous_beacon_value,
                        slice_tensor(concat_value, end=-raw_size_to_cache, dim=self.v_seq_dim)
                    ], dim=self.v_seq_dim)
                    raw_key = slice_tensor(concat_key, start=-raw_size_to_cache, dim=self.k_seq_dim)
                    raw_value = slice_tensor(concat_value, start=-raw_size_to_cache, dim=self.v_seq_dim)

                else:
                    # nothing left to keep raw; absorb everything
                    beacon_key = cat_tensor([
                        previous_beacon_key,
                        key,
                    ], dim=self.k_seq_dim)
                    beacon_value = cat_tensor([
                        previous_beacon_value,
                        value,
                    ], dim=self.v_seq_dim)
                    raw_key = None
                    raw_value = None

            else:
                # regular condensing: the last beacon_size positions of this
                # step's activations are the beacon activations
                beacon_key = cat_tensor([
                    previous_beacon_key,
                    slice_tensor(key, start=-beacon_size, dim=self.k_seq_dim)
                ], dim=self.k_seq_dim)
                beacon_value = cat_tensor([
                    previous_beacon_value,
                    slice_tensor(value, start=-beacon_size, dim=self.v_seq_dim)
                ], dim=self.v_seq_dim)

                if key.shape[self.k_seq_dim] < raw_size_to_cache + beacon_size:
                    # this step alone is too short; pull the remainder of the
                    # raw tail from the previous raw cache
                    concat_raw_key = cat_tensor([
                        previous_raw_key,
                        slice_tensor(key, end=-beacon_size, dim=self.k_seq_dim)
                    ], dim=self.k_seq_dim)
                    concat_raw_value = cat_tensor([
                        previous_raw_value,
                        slice_tensor(value, end=-beacon_size, dim=self.v_seq_dim)
                    ], dim=self.v_seq_dim)
                    raw_key = slice_tensor(concat_raw_key, start=-raw_size_to_cache, dim=self.k_seq_dim)
                    raw_value = slice_tensor(concat_raw_value, start=-raw_size_to_cache, dim=self.v_seq_dim)
                else:
                    # the raw tail fits entirely inside this step's activations
                    raw_key = slice_tensor(key, start=-raw_size_to_cache - beacon_size, end=-beacon_size, dim=self.k_seq_dim)
                    raw_value = slice_tensor(value, start=-raw_size_to_cache - beacon_size, end=-beacon_size, dim=self.v_seq_dim)

            self.beacon_activations[layer_idx] = (beacon_key, beacon_value)
            self.raw_activations[layer_idx] = (raw_key, raw_value)

    def update_loss(self, batch_loss, valid_token_num):
        """
        Accumulate loss for later perplexity computation and backward pass; past_key_values according to cache_method.
        """
        if self.batch_loss is None:
            # store token-weighted loss so windows of different sizes are
            # averaged correctly later
            self.batch_loss = batch_loss * valid_token_num
            self.valid_token_num = valid_token_num
        else:
            self.batch_loss = self.batch_loss + batch_loss * valid_token_num
            self.valid_token_num = self.valid_token_num + valid_token_num

    def output(self, model_outputs):
        """
        Override loss with accumulated loss.

        Also trims beacon-token positions off the logits of the last step.
        """
        if self.batch_loss is not None:
            # token-weighted average over all windows and the whole batch
            loss = self.batch_loss.sum() / self.valid_token_num.sum()

            # per-instance loss; guard against division by zero for
            # instances with no valid tokens
            batch_loss = self.batch_loss / self.valid_token_num
            if (self.valid_token_num == 0).any():
                batch_loss = batch_loss.masked_fill(self.valid_token_num == 0, 0.)

            model_outputs["loss"] = loss
            model_outputs["batch_loss"] = batch_loss
            model_outputs["valid_token_num"] = self.valid_token_num

        # remove logits of the appended beacon tokens from the last step
        beacon_size = self._beacon_sizes[-1]
        if beacon_size > 0:
            model_outputs["logits"] = model_outputs["logits"][:, :-beacon_size]

        return model_outputs
|
|
|
|
|
def slice_tensor(x, start=None, end=None, dim=2):
    """Slice ``x`` along dimension ``dim`` as ``x[..., start:end, ...]``.

    Degenerate results are represented as ``None`` (rather than empty
    tensors) so that ``cat_tensor`` can simply drop them:

    - ``x`` is None, or
    - ``end == 0``, or
    - ``start == x.shape[dim]``, or
    - ``start == end`` (also covers start and end both None).

    Generalized from the original hard-coded ``dim in (1, 2)`` branches
    (which raised NotImplementedError otherwise) to any dimension, by
    building the index tuple dynamically; behavior for dims 1 and 2 is
    unchanged.
    """
    if x is None:
        return None
    if end == 0:
        return None
    if start == x.shape[dim]:
        return None
    if start == end:
        return None
    # full slice everywhere except `dim`
    index = [slice(None)] * x.dim()
    index[dim] = slice(start, end)
    return x[tuple(index)]
|
|
|
def cat_tensor(list_of_tensors, dim=-1):
    """Concatenate tensors along ``dim``, ignoring None entries.

    Returns None when every entry is None, the single tensor itself when
    only one remains (avoiding a copy), and ``torch.cat`` otherwise.
    """
    tensors = [t for t in list_of_tensors if t is not None]
    if not tensors:
        return None
    if len(tensors) == 1:
        return tensors[0]
    return torch.cat(tensors, dim=dim)
|
|
|
def softmax(x:np.ndarray, axis=-1, temperature=1):
    """Numerically stable, temperature-scaled softmax along ``axis``.

    Accepts a plain list as well as an ndarray. The maximum is subtracted
    before exponentiation to avoid overflow.
    """
    if isinstance(x, list):
        x = np.array(x)
    scaled = x / temperature
    shifted = scaled - scaled.max(axis=axis, keepdims=True)
    exponentiated = np.exp(shifted)
    return exponentiated / exponentiated.sum(axis=axis, keepdims=True)
|
|
|
def l1_norm(x):
    """Scale the entries of ``x`` so they sum to 1 (L1 normalization)."""
    total = sum(x)
    return [value / total for value in x]
|
|