from collections import namedtuple
from dataclasses import dataclass
from typing import Tuple, Optional

import torch


@dataclass
class LongLlamaMemConfig:
    """
    Class for configuring memory caches for the LongLlama model.

    Args:
        positionals (`bool`)
            Whether to use positional embeddings in the memory layer.
        cache_dtype (`torch.dtype`)
            Specifies the storage type for keys and values.
        attention_grouping (`Tuple[int, int]`, *optional*)
            One can trade speed for memory by performing attention
            in memory layers sequentially.
            When equal to `(4, 128)`, the memory layers will process at most 4 heads and
            128 queries from each head at once, that is, at most 512 queries at once.
    """

    positionals: bool = True
    cache_dtype: torch.dtype = torch.bfloat16
    attention_grouping: Optional[Tuple[int, int]] = None
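

# Illustrative example (not part of the original file): a config that keeps the memory
# cache in bfloat16 and caps memory-layer attention at 4 heads and 128 queries per head
# at a time, trading speed for lower peak memory as described in the docstring above:
#
#     mem_config = LongLlamaMemConfig(
#         positionals=True,
#         cache_dtype=torch.bfloat16,
#         attention_grouping=(4, 128),
#     )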


@dataclass
class LongLlamaMemCache:
    """
    Class with LongLlama's memory cache.

    Args:
        keys (`torch.FloatTensor` of shape `(batch_size, num_heads, mem_length, embed_size_per_head)`)
        values (`torch.FloatTensor` of shape `(batch_size, num_heads, mem_length, embed_size_per_head)`)
        masks (`torch.FloatTensor` of shape `(batch_size, 1, mem_length, 1)`)
            For masking out parts of memory.
    """

    keys: torch.FloatTensor
    values: torch.FloatTensor
    masks: torch.FloatTensor


def mem_apply_update(
    prev_mem_cache: LongLlamaMemCache, new_mem_content: LongLlamaMemCache, mem_config: LongLlamaMemConfig
) -> LongLlamaMemCache:
    # Appends new keys, values, and masks to the existing memory cache
    # along the mem_length dimension (dim=-2).
    def update_one(prev, new):
        if len(prev.shape) != 4 or len(new.shape) != 4:
            raise ValueError(f"Memory cache content should be consistent in shape, got {prev.shape} {new.shape}")
        return torch.concat([prev, new], dim=-2)

    insert_size = new_mem_content.keys.shape[-2]

    if new_mem_content.values.shape[-2] != insert_size or new_mem_content.masks.shape[-2] != insert_size:
        raise ValueError("Inconsistent mem_length in new_mem_content")

    return LongLlamaMemCache(
        keys=update_one(prev_mem_cache.keys, new_mem_content.keys),
        values=update_one(prev_mem_cache.values, new_mem_content.values),
        masks=update_one(prev_mem_cache.masks, new_mem_content.masks),
    )
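

if __name__ == "__main__":
    # Minimal usage sketch (illustrative, not part of the original file). The tensor
    # shapes follow the LongLlamaMemCache docstring; batch_size=1, num_heads=2, and
    # embed_size_per_head=8 are arbitrary values picked for this demo.
    config = LongLlamaMemConfig(cache_dtype=torch.bfloat16, attention_grouping=(4, 128))

    def random_cache(mem_length: int) -> LongLlamaMemCache:
        # Demo-only helper: a cache with random keys/values and an all-ones mask
        # (i.e. no memory positions masked out).
        return LongLlamaMemCache(
            keys=torch.randn(1, 2, mem_length, 8, dtype=config.cache_dtype),
            values=torch.randn(1, 2, mem_length, 8, dtype=config.cache_dtype),
            masks=torch.ones(1, 1, mem_length, 1, dtype=config.cache_dtype),
        )

    prev_cache = random_cache(mem_length=16)
    new_content = random_cache(mem_length=4)

    # mem_apply_update concatenates along dim=-2, so mem_length grows from 16 to 20.
    updated = mem_apply_update(prev_cache, new_content, config)
    print(updated.keys.shape)  # torch.Size([1, 2, 20, 8])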