File size: 4,295 Bytes
0c0f76e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 |
from typing import Callable, Optional
import torch
from einops import rearrange
from diffusers.models.attention_processor import Attention
from diffusers.utils.import_utils import is_xformers_available
# Import xformers only when it is installed. `is_xformers_available` is a
# function and must be *called*: the bare function object is always truthy,
# which would make the import unconditional and the `else` branch dead.
if is_xformers_available():
    import xformers
    import xformers.ops
else:
    # Sentinel checked by XFormersCrossViewAttnProcessor before use.
    xformers = None
class CrossViewAttnProcessor:
    """Attention processor whose self-attention can span several views of a sample.

    For self-attention with ``num_views > 1``, the batch is regrouped with the
    einops pattern ``"(b n) l d -> b (n l) d"``. That pattern means every
    ``num_views`` consecutive batch entries have their token sequences
    concatenated, so each token attends to the tokens of all views in its group.
    The output is split back into per-view sequences afterwards.
    NOTE(review): this assumes the caller stores the views of one sample in
    consecutive batch positions; confirm against the caller.

    Cross-attention (``encoder_hidden_states`` given) is computed per batch entry,
    exactly like a plain attention processor.

    Args:
        num_views: Number of consecutive batch entries grouped together for
            self-attention. Values ``<= 1`` give ordinary attention.
    """

    def __init__(self, num_views: int = 1):
        self.num_views = num_views

    def __call__(
        self,
        attn: Attention,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
    ):
        """Compute attention with ``attn``'s projections, optionally across views.

        Args:
            attn: The diffusers ``Attention`` module providing the q/k/v
                projections, head split/merge helpers, scores and output layers.
            hidden_states: Query input; unpacked below as a 3-D
                ``(batch, seq_len, dim)`` tensor.
            encoder_hidden_states: Key/value input for cross-attention, or
                ``None`` for self-attention.
            attention_mask: Optional mask passed through
                ``attn.prepare_attention_mask``.

        Returns:
            The attention output after ``attn.to_out``. Its batch and sequence
            dimensions match ``hidden_states``.

        Raises:
            ValueError: If ``attention_mask`` is given for cross-view
                self-attention. The mask is prepared for the per-view batch and
                sequence length, but the scores span all views of a group, so it
                cannot be applied. Previously this failed later with an opaque
                shape-mismatch error.
        """
        is_cross_attention = encoder_hidden_states is not None
        merge_views = not is_cross_attention and self.num_views > 1
        if merge_views and attention_mask is not None:
            raise ValueError(
                "attention_mask is not supported for cross-view self-attention "
                "(num_views > 1): the mask is prepared per view, but attention "
                "spans all views of a sample."
            )

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.cross_attention_norm:
            encoder_hidden_states = attn.norm_cross(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        if merge_views:
            # Fold num_views consecutive batch entries into one long sequence so
            # self-attention sees every view of the group at once.
            query = rearrange(query, "(b n) l d -> b (n l) d", n=self.num_views)
            key = rearrange(key, "(b n) l d -> b (n l) d", n=self.num_views)
            value = rearrange(value, "(b n) l d -> b (n l) d", n=self.num_views)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if merge_views:
            # Undo the view merge so the output lines up with the input batch.
            hidden_states = rearrange(hidden_states, "b (n l) d -> (b n) l d", n=self.num_views)

        return hidden_states
class XFormersCrossViewAttnProcessor:
    """Cross-view attention processor backed by ``xformers.ops.memory_efficient_attention``.

    For self-attention with ``num_views > 1``, every ``num_views`` consecutive
    batch entries are merged into one sequence (einops
    ``"(b n) l d -> b (n l) d"``). Tokens can then attend across all views of a
    group, and the output is split back afterwards. Cross-attention is computed
    per batch entry.
    NOTE(review): this assumes the views of one sample occupy consecutive batch
    positions; confirm against the caller.

    Args:
        num_views: Number of consecutive batch entries grouped together for
            self-attention. Values ``<= 1`` give ordinary attention.
        attention_op: Optional xformers operator forwarded as ``op=`` to
            ``memory_efficient_attention``. ``None`` lets xformers choose.

    Raises:
        ImportError: If xformers is not installed. Without this check the module
            would only fail later, with an ``AttributeError`` on ``None.ops``
            at the first forward pass.
    """

    def __init__(
        self,
        num_views: int = 1,
        attention_op: Optional[Callable] = None,
    ):
        if xformers is None:
            raise ImportError(
                "XFormersCrossViewAttnProcessor requires xformers, which is not installed."
            )
        self.num_views = num_views
        self.attention_op = attention_op

    def __call__(
        self,
        attn: Attention,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
    ):
        """Compute memory-efficient attention with ``attn``'s projections.

        Args:
            attn: The diffusers ``Attention`` module providing the q/k/v
                projections, head split/merge helpers, ``scale`` and output layers.
            hidden_states: Query input; unpacked below as a 3-D
                ``(batch, seq_len, dim)`` tensor.
            encoder_hidden_states: Key/value input for cross-attention, or
                ``None`` for self-attention.
            attention_mask: Optional mask passed through
                ``attn.prepare_attention_mask`` and then used as xformers' ``attn_bias``.

        Returns:
            The attention output after ``attn.to_out``. Its batch and sequence
            dimensions match ``hidden_states``.

        Raises:
            ValueError: If ``attention_mask`` is given for cross-view
                self-attention. The prepared mask has the per-view shape, while
                attention spans all views of a group.
        """
        is_cross_attention = encoder_hidden_states is not None
        merge_views = not is_cross_attention and self.num_views > 1
        if merge_views and attention_mask is not None:
            raise ValueError(
                "attention_mask is not supported for cross-view self-attention "
                "(num_views > 1): the mask is prepared per view, but attention "
                "spans all views of a sample."
            )

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.cross_attention_norm:
            encoder_hidden_states = attn.norm_cross(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        if merge_views:
            # Fold num_views consecutive batch entries into one long sequence so
            # self-attention sees every view of the group at once.
            query = rearrange(query, "(b n) l d -> b (n l) d", n=self.num_views)
            key = rearrange(key, "(b n) l d -> b (n l) d", n=self.num_views)
            value = rearrange(value, "(b n) l d -> b (n l) d", n=self.num_views)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        hidden_states = xformers.ops.memory_efficient_attention(
            query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
        )
        # Cast the kernel output back to the query dtype.
        hidden_states = hidden_states.to(query.dtype)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if merge_views:
            # Undo the view merge so the output lines up with the input batch.
            hidden_states = rearrange(hidden_states, "b (n l) d -> (b n) l d", n=self.num_views)

        return hidden_states
|