import functools

import torch
from diffusers.models.attention import BasicTransformerBlock
from diffusers.utils.import_utils import is_xformers_available

from .lora import LoraInjectedLinear

if is_xformers_available():
    import xformers
    import xformers.ops
else:
    xformers = None
@functools.cache
def test_xformers_backwards(size):
    """Check whether xformers memory-efficient attention supports a backward
    pass for head dimension `size`. Results are cached per size."""

    @torch.enable_grad()
    def _grad(size):
        # Build a tiny attention problem whose head dimension is `size`.
        q = torch.randn((1, 4, size), device="cuda")
        k = torch.randn((1, 4, size), device="cuda")
        v = torch.randn((1, 4, size), device="cuda")

        q = q.detach().requires_grad_()
        k = k.detach().requires_grad_()
        v = v.detach().requires_grad_()

        out = xformers.ops.memory_efficient_attention(q, k, v)
        loss = out.sum(2).mean(0).sum()

        # Backward through the xformers kernel; unsupported head dimensions
        # raise here.
        return torch.autograd.grad(loss, v)

    try:
        _grad(size)
        print(size, "pass")
        return True
    except Exception:
        print(size, "fail")
        return False
def set_use_memory_efficient_attention_xformers(
    module: torch.nn.Module, valid: bool
) -> None:
    """Enable or disable xformers attention on `module`. When enabling,
    recursively test each BasicTransformerBlock's dim_head and fall back to
    standard attention for blocks whose backward pass is unsupported."""

    def fn_test_dim_head(module: torch.nn.Module):
        if isinstance(module, BasicTransformerBlock):
            # dim_head isn't stored anywhere, so back-calculate it from the
            # value projection's output features and the number of heads.
            source = module.attn1.to_v
            if isinstance(source, LoraInjectedLinear):
                source = source.linear

            dim_head = source.out_features // module.attn1.heads

            result = test_xformers_backwards(dim_head)

            # If the backward pass fails for this dim_head, turn xformers off
            # for this block only.
            if not result:
                module.set_use_memory_efficient_attention_xformers(False)

        for child in module.children():
            fn_test_dim_head(child)

    if not is_xformers_available() and valid:
        print("XFormers is not available. Skipping.")
        return

    module.set_use_memory_efficient_attention_xformers(valid)

    if valid:
        fn_test_dim_head(module)
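

# A minimal usage sketch, not part of the original module: load a diffusers
# UNet and let the helper above enable xformers attention, falling back
# per-block where the backward test fails for that block's dim_head.
# The model id "runwayml/stable-diffusion-v1-5" is only an example here.
if __name__ == "__main__":
    from diffusers import UNet2DConditionModel

    unet = UNet2DConditionModel.from_pretrained(
        "runwayml/stable-diffusion-v1-5", subfolder="unet"
    ).to("cuda")

    # Request xformers everywhere; blocks whose dim_head fails the backward
    # check are switched back to the standard attention path.
    set_use_memory_efficient_attention_xformers(unet, True)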