Spaces:
Running
on
Zero
Running
on
Zero
File size: 9,446 Bytes
2422035 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 |
# Modified from:
# gpt-fast: https://github.com/pytorch-labs/gpt-fast/blob/main/generate.py
# DiT: https://github.com/facebookresearch/DiT/blob/main/models.py
import torch
import torch.nn as nn
from torch.nn import functional as F
import torch._dynamo.config
import torch._inductor.config
import copy
import time
# torch._inductor.config.coordinate_descent_tuning = True
# torch._inductor.config.triton.unique_kernel_names = True
# torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future
### from https://huggingface.co/transformers/v3.2.0/_modules/transformers/generation_utils.html
def top_k_top_p_filtering(
    logits,
    top_k: int = 0,
    top_p: float = 1.0,
    filter_value: float = -float("Inf"),
    min_tokens_to_keep: int = 1,
):
    """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering.

    Filtering is applied along the last dimension, so ``logits`` may have any
    number of leading dimensions (e.g. ``(batch, vocab)``).

    Args:
        logits: Logits tensor whose last dimension is the vocabulary. It is
            modified in place and also returned.
        top_k: If > 0, keep only the ``top_k`` highest-scoring tokens
            (top-k filtering). ``None`` is treated like 0 (disabled).
        top_p: If < 1.0, keep the highest-scoring tokens until their cumulative
            probability reaches ``top_p`` (nucleus filtering, Holtzman et al.,
            http://arxiv.org/abs/1904.09751).
        filter_value: Value written into every filtered-out position.
        min_tokens_to_keep: Lower bound on the number of tokens kept per row.

    Returns:
        The filtered ``logits`` (the same tensor object that was passed in).

    From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    if top_k is not None and top_k > 0:
        # Safety check: keep at least min_tokens_to_keep, at most the vocab size.
        top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1))
        # Remove every token whose logit is below the k-th largest logit of its row.
        kth_largest = torch.topk(logits, top_k)[0][..., -1, None]
        logits[logits < kth_largest] = filter_value
    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        # Remove tokens whose cumulative probability is above the threshold.
        sorted_indices_to_remove = cumulative_probs > top_p
        if min_tokens_to_keep > 1:
            # Keep at least min_tokens_to_keep tokens regardless of probability mass.
            sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
        # Shift right so the first token that crosses the threshold is kept too.
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0
        # Map the mask from sorted order back to vocabulary order. Scatter along
        # the last dim (the axis sort() worked on) so inputs of any rank work.
        indices_to_remove = sorted_indices_to_remove.scatter(-1, sorted_indices, sorted_indices_to_remove)
        logits[indices_to_remove] = filter_value
    return logits
def sample(logits, temperature: float=1.0, top_k: int=2000, top_p: float=1.0, sample_logits=True):
    """Choose the next token from the logits at the last sequence position.

    Args:
        logits: Model output; only ``logits[:, -1, :]`` (the last position) is used.
        temperature: Divisor applied to the logits, clamped below at 1e-5 so a
            temperature of 0 does not divide by zero.
        top_k: Top-k threshold passed to ``top_k_top_p_filtering``.
        top_p: Nucleus threshold passed to ``top_k_top_p_filtering``.
            Filtering is skipped entirely when ``top_k <= 0`` and ``top_p >= 1.0``.
        sample_logits: If True, draw the token with ``torch.multinomial``;
            otherwise pick the most probable token (greedy).

    Returns:
        ``(idx, probs)``: chosen token indices (trailing dim of size 1) and the
        probability distribution they were chosen from.
    """
    last_step = logits[:, -1, :] / max(temperature, 1e-5)
    needs_filtering = top_k > 0 or top_p < 1.0
    if needs_filtering:
        last_step = top_k_top_p_filtering(last_step, top_k=top_k, top_p=top_p)
    distribution = F.softmax(last_step, dim=-1)
    if not sample_logits:
        greedy_idx = torch.topk(distribution, k=1, dim=-1)[1]
        return greedy_idx, distribution
    drawn_idx = torch.multinomial(distribution, num_samples=1)
    return drawn_idx, distribution
def logits_to_probs(logits, temperature: float = 1.0, top_p: float=1.0, top_k: int = None, **kwargs):
    """Convert logits to probabilities, optionally applying top-k / top-p filtering.

    Unlike ``sample``, this operates on the whole ``logits`` tensor rather than
    slicing out the last position.

    Args:
        logits: Logits tensor; softmax is taken over the last dimension.
        temperature: Divisor applied to the logits, clamped below at 1e-5.
        top_p: Nucleus threshold; values < 1.0 enable top-p filtering.
        top_k: Top-k threshold; ``None`` (the default) or a value <= 0 disables
            top-k filtering.
        **kwargs: Accepted and ignored.

    Returns:
        Softmax probabilities with the same shape as ``logits``.
    """
    logits = logits / max(temperature, 1e-5)
    # top_k defaults to None; comparing None > 0 would raise TypeError, so treat
    # None as "disabled" explicitly.
    use_top_k = top_k is not None and top_k > 0
    if use_top_k or top_p < 1.0:
        logits = top_k_top_p_filtering(logits, top_k=top_k if use_top_k else 0, top_p=top_p)
    probs = torch.nn.functional.softmax(logits, dim=-1)
    return probs
def prefill(model, cond_idx: torch.Tensor, input_pos: torch.Tensor, cfg_scale: float, condition:torch.Tensor, **sampling_kwargs):
    """Run the conditioning prefix through the model and sample the first new token.

    The model is called as ``model(None, cond_idx, input_pos, condition=condition)``.
    When ``cfg_scale > 1.0`` the batch is assumed to be the conditional half
    followed by the unconditional half; the halves are combined with
    classifier-free guidance before sampling.

    Returns:
        Only the sampled token indices; the probabilities from ``sample`` are dropped.
    """
    prefix_logits, _ = model(None, cond_idx, input_pos, condition=condition)
    if cfg_scale > 1.0:
        half = len(prefix_logits) // 2
        with_cond, without_cond = torch.split(prefix_logits, half, dim=0)
        prefix_logits = without_cond + (with_cond - without_cond) * cfg_scale
    first_token, _unused_probs = sample(prefix_logits, **sampling_kwargs)
    return first_token
def decode_one_token(model, x: torch.Tensor, input_pos: torch.Tensor, cfg_scale: float, cfg_flag: bool, condition: torch.Tensor, **sampling_kwargs):
    """Run one cached decoding step and sample the following token.

    Args:
        model: Called as ``model(tokens, cond_idx=None, input_pos=..., condition=...)``.
        x: Current token(s) for each sequence in the batch.
        input_pos: Position of ``x``; its last dimension must have size 1.
        cfg_scale: Guidance scale. When > 1.0, ``x`` is duplicated so the model
            sees a conditional and an unconditional copy.
        cfg_flag: Only consulted when ``cfg_scale > 1.0``. True mixes the two
            halves with guidance; False uses the conditional logits alone.
        condition: Forwarded to the model only on the guidance path.

    Returns:
        ``(idx, probs)`` exactly as returned by ``sample``.
    """
    assert input_pos.shape[-1] == 1
    use_cfg = cfg_scale > 1.0
    if not use_cfg:
        # NOTE(review): this path passes condition=None, whereas the guidance
        # path and prefill forward `condition`; confirm the model does not need it.
        step_logits, _ = model(x, cond_idx=None, input_pos=input_pos, condition=None)
        return sample(step_logits, **sampling_kwargs)

    doubled = torch.cat([x, x])
    step_logits, _ = model(doubled, cond_idx=None, input_pos=input_pos, condition=condition)
    with_cond, without_cond = torch.split(step_logits, len(step_logits) // 2, dim=0)
    if cfg_flag:
        guided = without_cond + (with_cond - without_cond) * cfg_scale
    else:
        guided = with_cond
    return sample(guided, **sampling_kwargs)
def decode_n_tokens(
    model, cur_token: torch.Tensor, input_pos: torch.Tensor, num_new_tokens: int,
    cfg_scale: float, cfg_interval: int, condition: torch.Tensor,
    **sampling_kwargs):
    """Autoregressively decode ``num_new_tokens`` tokens, one model step at a time.

    ``input_pos`` is incremented in place after every step. The guidance flag
    passed to ``decode_one_token`` is True for steps ``0..cfg_interval``
    (inclusive) and False afterwards. A ``cfg_interval`` of -1 or lower keeps it
    True for every step.

    Returns:
        ``(tokens, probs)``: two lists holding one cloned tensor per step.
    """
    collected_tokens, collected_probs = [], []
    for step in range(num_new_tokens):
        # Force the math SDPA backend (flash / memory-efficient disabled); the
        # upstream note says this is "actually better for Inductor to codegen attention here".
        with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
            # The step index only grows, so evaluating this per step is
            # equivalent to a flag that is switched off once and stays off.
            guidance_on = not (cfg_interval > -1 and step > cfg_interval)
            token, token_probs = decode_one_token(
                model, cur_token, input_pos, cfg_scale, guidance_on, condition=condition, **sampling_kwargs
            )
            input_pos += 1
            collected_tokens.append(token.clone())
            collected_probs.append(token_probs.clone())
            cur_token = token.view(-1, 1)
    return collected_tokens, collected_probs
@torch.no_grad()
def generate(model, cond, max_new_tokens, emb_masks=None, cfg_scale=1.0, cfg_interval=-1, condition=None, condition_null=None, condition_token_nums=0, **sampling_kwargs):
    """Generate ``max_new_tokens`` tokens for each entry of the batch.

    Args:
        model: Autoregressive model exposing ``model_type`` ('c2i' or 't2i'),
            ``setup_caches``, ``causal_mask`` and ``tok_embeddings``; also
            ``adapter`` / ``adapter_mlp`` when ``condition`` is given.
        cond: Class indices ('c2i') or prompt embeddings ('t2i');
            ``cond.shape[0]`` is the batch size.
        max_new_tokens: Number of tokens to generate; must be >= 1.
        emb_masks: Optional mask whose first dim is the batch size and last dim
            is the prefix length ``T``; it is multiplied into
            ``model.causal_mask[:, :, :T]``.
        cfg_scale: Classifier-free guidance scale. Values > 1.0 double the batch
            with an unconditional copy.
        cfg_interval: Forwarded to ``decode_n_tokens``; -1 keeps guidance on for
            every decoding step.
        condition: Optional extra conditioning, passed through
            ``model.adapter`` and ``model.adapter_mlp`` first.
        condition_null: Currently ignored. The unconditional counterpart is always
            ``torch.zeros_like`` of the adapted ``condition``. Kept for API compatibility.
        condition_token_nums: Number of condition tokens in the prefix ('c2i' only).
        **sampling_kwargs: Forwarded to ``sample`` (e.g. temperature, top_k,
            top_p, sample_logits).

    Returns:
        Integer tensor of shape ``(batch, max_new_tokens)`` with the generated ids.

    Raises:
        ValueError: If ``max_new_tokens < 1``.
        Exception: If ``model.model_type`` is neither 'c2i' nor 't2i'.
    """
    if max_new_tokens < 1:
        raise ValueError(f"max_new_tokens must be >= 1, got {max_new_tokens}")
    use_cfg = cfg_scale > 1.0

    if condition is not None:
        condition = model.adapter(condition)
        condition = model.adapter_mlp(condition)

    # Per model type: the prefix length T and, for guidance, the "null" cond.
    if model.model_type == 'c2i':
        if use_cfg:
            # Every entry is set to the class id num_classes.
            cond_null = torch.ones_like(cond) * model.num_classes
        T = 1 + condition_token_nums
    elif model.model_type == 't2i':
        if use_cfg:
            cond_null = torch.zeros_like(cond) + model.cls_embedding.uncond_embedding
        T = cond.shape[1]
    else:
        raise Exception("please check model type")

    # Guidance batches are laid out as [conditional; unconditional].
    if use_cfg:
        cond_combined = torch.cat([cond, cond_null])
        if condition is not None:
            condition_combined = torch.cat((condition, torch.zeros_like(condition)), dim=0)
        else:
            condition_combined = None
    else:
        cond_combined = cond
        condition_combined = condition

    T_new = T + max_new_tokens
    max_batch_size = cond.shape[0]
    device = cond.device
    with torch.device(device):
        max_batch_size_cfg = max_batch_size * 2 if use_cfg else max_batch_size
        model.setup_caches(max_batch_size=max_batch_size_cfg, max_seq_length=T_new, dtype=model.tok_embeddings.weight.dtype)

    if emb_masks is not None:
        assert emb_masks.shape[0] == max_batch_size, "emb_masks batch size must match cond"
        assert emb_masks.shape[-1] == T, "emb_masks last dim must equal the prefix length T"
        prefix_masks = torch.cat([emb_masks, emb_masks]) if use_cfg else emb_masks
        model.causal_mask[:, :, :T] = model.causal_mask[:, :, :T] * prefix_masks.unsqueeze(1)
        # Force the diagonal to 1 so every position can attend at least to
        # itself (presumably to avoid fully-masked rows once padding is masked).
        eye_matrix = torch.eye(model.causal_mask.size(1), model.causal_mask.size(2), device=device)
        model.causal_mask[:] = model.causal_mask * (1 - eye_matrix) + eye_matrix

    # Preallocated output; only positions T: are filled and returned.
    seq = torch.empty((max_batch_size, T_new), dtype=torch.int, device=device)

    input_pos = torch.arange(0, T, device=device)
    next_token = prefill(model, cond_combined, input_pos, cfg_scale, condition_combined, **sampling_kwargs)
    seq[:, T:T+1] = next_token

    # With max_new_tokens == 1 decode_n_tokens would return an empty list and
    # torch.cat([]) raises, so only decode further when more tokens are needed.
    if max_new_tokens > 1:
        input_pos = torch.tensor([T], device=device, dtype=torch.int)
        generated_tokens, _ = decode_n_tokens(model, next_token, input_pos, max_new_tokens - 1, cfg_scale, cfg_interval, condition=condition_combined, **sampling_kwargs)
        seq[:, T+1:] = torch.cat(generated_tokens, dim=1)
    return seq[:, T:]
|