# ------------------------------------------------------------------------
#
#   Ultimate VAE Tile Optimization
#
#   Introducing a revolutionary new optimization designed to make
#   the VAE work with giant images on limited VRAM!
#   Say goodbye to the frustration of OOM and hello to seamless output!
#
# ------------------------------------------------------------------------
#
#   This script is a wild hack that splits the image into tiles,
#   encodes each tile separately, and merges the result back together.
#
#   Advantages:
#   - The VAE can now work with giant images on limited VRAM
#       (~10 GB for 8K images!)
#   - The merged output is completely seamless without any post-processing.
#
#   Drawbacks:
#   - Giant RAM needed. To store the intermediate results for a 4096x4096
#       image, you need a 32 GB RAM machine (it consumes ~20 GB); for
#       8192x8192 you need a 128 GB RAM machine (it consumes ~100 GB).
#   - NaNs always appear for 8K images when you use the fp16 (half) VAE.
#       You must use --no-half-vae to disable the half VAE for such giant images.
#   - Slow speed. With the default tile size, it takes around 50/200 seconds
#       to encode/decode a 4096x4096 image, and 200/900 seconds to
#       encode/decode an 8192x8192 image. (The speed is limited by both the
#       GPU and the CPU.)
#   - The gradient calculation is not compatible with this hack. It
#       will break any backward() or torch.autograd.grad() that passes
#       through the VAE. (But you can still use the VAE to generate
#       training data.)
#
#   How it works:
#   1) The image is split into tiles.
#       - To ensure perfect results, each tile is padded on each side
#         (32 pixels for the encoder, 11 latent pixels for the decoder).
#       - Then the conv2d/silu/upsample/downsample can produce identical
#         results to the original image without splitting.
#   2) The original forward is decomposed into a task queue and a task worker.
#       - The task queue is a list of functions that will be executed in order.
#       - The task worker is a loop that executes the tasks in the queue.
#   3) The task queue is executed for each tile.
#       - The current tile is sent to the GPU.
#       - Local operations are directly executed.
#       - Group norm calculation is temporarily suspended until the mean
#         and var of all tiles are calculated.
#       - The residual is pre-calculated, stored, and added back later.
#       - When moving on to the next tile, the current tile is sent to the
#         CPU (the tile at either end of a sweep stays on the GPU because the
#         sweep reverses direction there).
#   4) After all tiles are processed, the cropped tiles are written into a
#      single result tensor, which is returned.
#
#   Enjoy!
#
#   @author: LI YI @ Nanyang Technological University - Singapore
#   @date: 2023-03-02
#   @license: MIT License
#
#   Please give me a star if you like this project!
#
# -------------------------------------------------------------------------

import functools
import gc
import math
from time import time

from tqdm import tqdm

import torch
import torch.version
import torch.nn.functional as F
from einops import rearrange
from diffusers.utils.import_utils import is_xformers_available

import SUPIR.utils.devices as devices

try:
    import xformers
    import xformers.ops
except ImportError:
    pass

# Selects which attribute layout the VAE encoder/decoder is assumed to have
# when task queues are built:
#   True  -> ldm/sgm-style names (``mid.block_1``, ``down``, ``up``, ``norm_out``)
#   False -> diffusers-style names (``mid_block``, ``down_blocks``,
#            ``up_blocks``, ``conv_norm_out``)
sd_flag = True


def get_recommend_encoder_tile_size():
    """Return a default encoder tile size chosen from total CUDA memory.

    Total memory of ``devices.device`` is read in MiB and bucketed;
    without CUDA the conservative size 512 is returned.
    """
    if torch.cuda.is_available():
        total_memory = torch.cuda.get_device_properties(
            devices.device).total_memory // 2**20
        if total_memory > 16 * 1000:
            ENCODER_TILE_SIZE = 3072
        elif total_memory > 12 * 1000:
            ENCODER_TILE_SIZE = 2048
        elif total_memory > 8 * 1000:
            ENCODER_TILE_SIZE = 1536
        else:
            ENCODER_TILE_SIZE = 960
    else:
        ENCODER_TILE_SIZE = 512
    return ENCODER_TILE_SIZE


def get_recommend_decoder_tile_size():
    """Return a default decoder tile size chosen from total CUDA memory.

    Decoder tiles are measured on the decoder input (the latent), hence the
    much smaller numbers than for the encoder. Without CUDA, 64 is returned.
    """
    if torch.cuda.is_available():
        total_memory = torch.cuda.get_device_properties(
            devices.device).total_memory // 2**20
        if total_memory > 30 * 1000:
            DECODER_TILE_SIZE = 256
        elif total_memory > 16 * 1000:
            DECODER_TILE_SIZE = 192
        elif total_memory > 12 * 1000:
            DECODER_TILE_SIZE = 128
        elif total_memory > 8 * 1000:
            DECODER_TILE_SIZE = 96
        else:
            DECODER_TILE_SIZE = 64
    else:
        DECODER_TILE_SIZE = 64
    return DECODER_TILE_SIZE


# Module-level defaults (tile sizes are probed once, at import time).
DEFAULT_ENABLED = False
DEFAULT_MOVE_TO_GPU = False
DEFAULT_FAST_ENCODER = True
DEFAULT_FAST_DECODER = True
DEFAULT_COLOR_FIX = 0
DEFAULT_ENCODER_TILE_SIZE = get_recommend_encoder_tile_size()
DEFAULT_DECODER_TILE_SIZE = get_recommend_decoder_tile_size()


def inplace_nonlinearity(x):
    """Apply SiLU in place: ``x`` is overwritten and returned."""
    # Test: fix for Nans
    return F.silu(x, inplace=True)


# from diffusers lib
def attn_forward_new(self, h_):
    """Attention forward for a diffusers-style ``Attention`` module
    (``to_q``/``to_k``/``to_v``/``to_out``) using an explicit ``bmm``.

    ``h_`` is unpacked as ``(B, C, H, W)``; the output has the same shape.
    """
    batch_size, channel, height, width = h_.shape
    hidden_states = h_.view(batch_size, channel, height * width).transpose(1, 2)

    attention_mask = None
    encoder_hidden_states = None
    batch_size, sequence_length, _ = hidden_states.shape
    attention_mask = self.prepare_attention_mask(attention_mask, sequence_length, batch_size)

    query = self.to_q(hidden_states)

    if encoder_hidden_states is None:
        encoder_hidden_states = hidden_states
    elif self.norm_cross:
        encoder_hidden_states = self.norm_encoder_hidden_states(encoder_hidden_states)

    key = self.to_k(encoder_hidden_states)
    value = self.to_v(encoder_hidden_states)

    query = self.head_to_batch_dim(query)
    key = self.head_to_batch_dim(key)
    value = self.head_to_batch_dim(value)

    attention_probs = self.get_attention_scores(query, key, attention_mask)
    hidden_states = torch.bmm(attention_probs, value)
    hidden_states = self.batch_to_head_dim(hidden_states)

    # linear proj
    hidden_states = self.to_out[0](hidden_states)
    # dropout
    hidden_states = self.to_out[1](hidden_states)

    hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

    return hidden_states


def attn_forward_new_pt2_0(self, hidden_states,):
    """Attention forward for a diffusers-style ``Attention`` module using
    ``F.scaled_dot_product_attention`` (PyTorch 2.x).

    Accepts a 4-D ``(B, C, H, W)`` input (flattened to a sequence and
    restored on return) or an already-flattened 3-D sequence.
    """
    scale = 1
    attention_mask = None
    encoder_hidden_states = None

    input_ndim = hidden_states.ndim

    if input_ndim == 4:
        batch_size, channel, height, width = hidden_states.shape
        hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

    batch_size, sequence_length, _ = (
        hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
    )

    if attention_mask is not None:
        attention_mask = self.prepare_attention_mask(attention_mask, sequence_length, batch_size)
        # scaled_dot_product_attention expects attention_mask shape to be
        # (batch, heads, source_length, target_length)
        attention_mask = attention_mask.view(batch_size, self.heads, -1, attention_mask.shape[-1])

    if self.group_norm is not None:
        hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

    query = self.to_q(hidden_states, scale=scale)

    if encoder_hidden_states is None:
        encoder_hidden_states = hidden_states
    elif self.norm_cross:
        encoder_hidden_states = self.norm_encoder_hidden_states(encoder_hidden_states)

    key = self.to_k(encoder_hidden_states, scale=scale)
    value = self.to_v(encoder_hidden_states, scale=scale)

    inner_dim = key.shape[-1]
    head_dim = inner_dim // self.heads

    query = query.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
    key = key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
    value = value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)

    # the output of sdp = (batch, num_heads, seq_len, head_dim)
    # TODO: add support for attn.scale when we move to Torch 2.1
    hidden_states = F.scaled_dot_product_attention(
        query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
    )

    hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.heads * head_dim)
    hidden_states = hidden_states.to(query.dtype)

    # linear proj
    hidden_states = self.to_out[0](hidden_states, scale=scale)
    # dropout
    hidden_states = self.to_out[1](hidden_states)

    if input_ndim == 4:
        hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

    return hidden_states


def attn_forward_new_xformers(self, hidden_states):
    """Attention forward for a diffusers-style ``Attention`` module using
    ``xformers.ops.memory_efficient_attention``.

    Same input/output conventions as ``attn_forward_new_pt2_0``.
    """
    scale = 1
    attention_op = None
    attention_mask = None
    encoder_hidden_states = None

    input_ndim = hidden_states.ndim

    if input_ndim == 4:
        batch_size, channel, height, width = hidden_states.shape
        hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

    batch_size, key_tokens, _ = (
        hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
    )

    attention_mask = self.prepare_attention_mask(attention_mask, key_tokens, batch_size)
    if attention_mask is not None:
        # expand our mask's singleton query_tokens dimension:
        #   [batch*heads,            1, key_tokens] ->
        #   [batch*heads, query_tokens, key_tokens]
        # so that it can be added as a bias onto the attention scores that xformers computes:
        #   [batch*heads, query_tokens, key_tokens]
        # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
        _, query_tokens, _ = hidden_states.shape
        attention_mask = attention_mask.expand(-1, query_tokens, -1)

    if self.group_norm is not None:
        hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

    query = self.to_q(hidden_states, scale=scale)

    if encoder_hidden_states is None:
        encoder_hidden_states = hidden_states
    elif self.norm_cross:
        encoder_hidden_states = self.norm_encoder_hidden_states(encoder_hidden_states)

    key = self.to_k(encoder_hidden_states, scale=scale)
    value = self.to_v(encoder_hidden_states, scale=scale)

    query = self.head_to_batch_dim(query).contiguous()
    key = self.head_to_batch_dim(key).contiguous()
    value = self.head_to_batch_dim(value).contiguous()

    hidden_states = xformers.ops.memory_efficient_attention(
        query, key, value, attn_bias=attention_mask, op=attention_op  # , scale=scale
    )
    hidden_states = hidden_states.to(query.dtype)
    hidden_states = self.batch_to_head_dim(hidden_states)

    # linear proj
    hidden_states = self.to_out[0](hidden_states, scale=scale)
    # dropout
    hidden_states = self.to_out[1](hidden_states)

    if input_ndim == 4:
        hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

    return hidden_states


# extracted from ldm.modules.diffusionmodules.model
def attn_forward(self, h_):
    """Single-head self-attention for an ldm/sgm-style block exposing
    ``q``/``k``/``v``/``proj_out``, computed with explicit ``bmm``.

    Materialises the full ``(B, HW, HW)`` attention matrix.
    """
    q = self.q(h_)
    k = self.k(h_)
    v = self.v(h_)

    # compute attention
    b, c, h, w = q.shape
    q = q.reshape(b, c, h * w)
    q = q.permute(0, 2, 1)     # b,hw,c
    k = k.reshape(b, c, h * w)  # b,c,hw
    w_ = torch.bmm(q, k)       # b,hw,hw    w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
    w_ = w_ * (int(c) ** (-0.5))
    w_ = torch.nn.functional.softmax(w_, dim=2)

    # attend to values
    v = v.reshape(b, c, h * w)
    w_ = w_.permute(0, 2, 1)   # b,hw,hw (first hw of k, second of q)
    # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
    h_ = torch.bmm(v, w_)
    h_ = h_.reshape(b, c, h, w)

    h_ = self.proj_out(h_)

    return h_


def xformer_attn_forward(self, h_):
    """Single-head self-attention for an ldm/sgm-style block
    (``q``/``k``/``v``/``proj_out``) using xformers' memory-efficient kernel.

    Uses the block's ``attention_op`` when it defines one; otherwise passes
    ``op=None`` so xformers selects an operator itself.
    """
    q = self.q(h_)
    k = self.k(h_)
    v = self.v(h_)

    # compute attention
    B, C, H, W = q.shape
    q, k, v = map(lambda x: rearrange(x, 'b c h w -> b (h w) c'), (q, k, v))

    q, k, v = map(
        lambda t: t.unsqueeze(3)
        .reshape(B, t.shape[1], 1, C)
        .permute(0, 2, 1, 3)
        .reshape(B * 1, t.shape[1], C)
        .contiguous(),
        (q, k, v),
    )
    # Plain (non-xformers) attention blocks have no ``attention_op`` attribute.
    out = xformers.ops.memory_efficient_attention(
        q, k, v, attn_bias=None, op=getattr(self, 'attention_op', None))

    out = (
        out.unsqueeze(0)
        .reshape(B, 1, out.shape[1], C)
        .permute(0, 2, 1, 3)
        .reshape(B, out.shape[1], C)
    )
    out = rearrange(out, 'b (h w) c -> b c h w', b=B, h=H, w=W, c=C)
    out = self.proj_out(out)
    return out


def attn2task(task_queue, net):
    """Append the tasks of one residual attention block to ``task_queue``.

    Sequence: store identity residual, group norm (``net.norm``), attention,
    add residual back.

    Attention implementation:
      * xformers importable          -> ``xformer_attn_forward``
      * else, ldm/sgm (``sd_flag``)  -> ``attn_forward``
      * else, diffusers layout       -> ``attn_forward_new_pt2_0`` when
        ``F.scaled_dot_product_attention`` exists, otherwise ``attn_forward_new``.
    """
    task_queue.append(('store_res', lambda x: x))
    task_queue.append(('pre_norm', net.norm))
    # ``is_xformers_available`` is a function: it must be *called*. Testing the
    # bare name is always truthy and would pick the xformers path even when
    # xformers is not installed.
    if is_xformers_available():
        attn_impl = xformer_attn_forward
    elif sd_flag:
        attn_impl = attn_forward
    elif hasattr(F, "scaled_dot_product_attention"):
        attn_impl = attn_forward_new_pt2_0
    else:
        attn_impl = attn_forward_new
    task_queue.append(('attn', lambda x, net=net, impl=attn_impl: impl(net, x)))
    task_queue.append(['add_res', None])


def resblock2task(queue, block):
    """
    Turn a ResNetBlock into a sequence of tasks and append to the task queue

    The residual (identity, or the shortcut conv when channel counts differ)
    is stored first; the block's main path is norm1/silu/conv1/norm2/silu/conv2;
    the ``add_res`` slot is filled with the residual at run time.

    @param queue: the target task queue
    @param block: ResNetBlock
    """
    if block.in_channels != block.out_channels:
        if sd_flag:
            if block.use_conv_shortcut:
                queue.append(('store_res', block.conv_shortcut))
            else:
                queue.append(('store_res', block.nin_shortcut))
        else:
            if block.use_in_shortcut:
                queue.append(('store_res', block.conv_shortcut))
            else:
                queue.append(('store_res', block.nin_shortcut))
    else:
        queue.append(('store_res', lambda x: x))
    queue.append(('pre_norm', block.norm1))
    queue.append(('silu', inplace_nonlinearity))
    queue.append(('conv1', block.conv1))
    queue.append(('pre_norm', block.norm2))
    queue.append(('silu', inplace_nonlinearity))
    queue.append(('conv2', block.conv2))
    queue.append(['add_res', None])


def build_sampling(task_queue, net, is_decoder):
    """
    Build the sampling part of a task queue

    Decoder: mid block first, then the up levels (with upsampling between them).
    Encoder: the down levels (with downsampling between them), then the mid block.

    @param task_queue: the target task queue
    @param net: the network
    @param is_decoder: currently building decoder or encoder
    """
    if is_decoder:
        if sd_flag:
            resblock2task(task_queue, net.mid.block_1)
            attn2task(task_queue, net.mid.attn_1)
            resblock2task(task_queue, net.mid.block_2)
            resolution_iter = reversed(range(net.num_resolutions))
            block_ids = net.num_res_blocks + 1
            condition = 0
            module = net.up
            func_name = 'upsample'
        else:
            resblock2task(task_queue, net.mid_block.resnets[0])
            attn2task(task_queue, net.mid_block.attentions[0])
            resblock2task(task_queue, net.mid_block.resnets[1])
            resolution_iter = range(len(net.up_blocks))  # net.num_resolutions = 3
            block_ids = 2 + 1
            condition = len(net.up_blocks) - 1
            module = net.up_blocks
            func_name = 'upsamplers'
    else:
        if sd_flag:
            resolution_iter = range(net.num_resolutions)
            block_ids = net.num_res_blocks
            condition = net.num_resolutions - 1
            module = net.down
            func_name = 'downsample'
        else:
            resolution_iter = range(len(net.down_blocks))
            block_ids = 2
            condition = len(net.down_blocks) - 1
            module = net.down_blocks
            func_name = 'downsamplers'

    for i_level in resolution_iter:
        for i_block in range(block_ids):
            if sd_flag:
                resblock2task(task_queue, module[i_level].block[i_block])
            else:
                resblock2task(task_queue, module[i_level].resnets[i_block])
        # Every level except ``condition`` ends with a resampling step.
        if i_level != condition:
            if sd_flag:
                task_queue.append((func_name, getattr(module[i_level], func_name)))
            else:
                if is_decoder:
                    task_queue.append((func_name, module[i_level].upsamplers[0]))
                else:
                    task_queue.append((func_name, module[i_level].downsamplers[0]))

    if not is_decoder:
        if sd_flag:
            resblock2task(task_queue, net.mid.block_1)
            attn2task(task_queue, net.mid.attn_1)
            resblock2task(task_queue, net.mid.block_2)
        else:
            resblock2task(task_queue, net.mid_block.resnets[0])
            attn2task(task_queue, net.mid_block.attentions[0])
            resblock2task(task_queue, net.mid_block.resnets[1])


def build_task_queue(net, is_decoder):
    """
    Build a single task queue for the encoder or decoder

    @param net: the VAE decoder or encoder network
    @param is_decoder: currently building decoder or encoder
    @return: the task queue (a list of ``(name, callable)`` entries; ``add_res``
             entries are mutable lists whose slot is filled at run time)
    """
    task_queue = []
    task_queue.append(('conv_in', net.conv_in))

    # construct the sampling part of the task queue
    # because encoder and decoder share the same architecture, we extract the sampling part
    build_sampling(task_queue, net, is_decoder)
    if is_decoder and not sd_flag:
        net.give_pre_end = False
        net.tanh_out = False

    if not is_decoder or not net.give_pre_end:
        if sd_flag:
            task_queue.append(('pre_norm', net.norm_out))
        else:
            task_queue.append(('pre_norm', net.conv_norm_out))
        task_queue.append(('silu', inplace_nonlinearity))
        task_queue.append(('conv_out', net.conv_out))
        if is_decoder and net.tanh_out:
            task_queue.append(('tanh', torch.tanh))

    return task_queue


def clone_task_queue(task_queue):
    """
    Clone a task queue

    Each task becomes a fresh list (so per-tile ``add_res`` slots can be
    filled independently); the callables themselves are shared.

    @param task_queue: the task queue to be cloned
    @return: the cloned task queue
    """
    return [list(task) for task in task_queue]


def get_var_mean(input, num_groups, eps=1e-6):
    """
    Get mean and var for group norm

    ``input`` must be 4-D ``(B, C, H, W)``. Returns the biased variance and
    the mean per (batch, group), each of shape ``(B * num_groups,)``.
    ``eps`` is accepted for signature symmetry but unused.
    """
    b, c = input.size(0), input.size(1)
    channel_in_group = int(c / num_groups)
    input_reshaped = input.contiguous().view(
        1, int(b * num_groups), channel_in_group, *input.size()[2:])
    var, mean = torch.var_mean(
        input_reshaped, dim=[0, 2, 3, 4], unbiased=False)
    return var, mean


def custom_group_norm(input, num_groups, mean, var, weight=None, bias=None, eps=1e-6):
    """
    Custom group norm with fixed mean and var

    @param input: input tensor
    @param num_groups: number of groups. by default, num_groups = 32
    @param mean: mean, must be pre-calculated by get_var_mean
    @param var: var, must be pre-calculated by get_var_mean
    @param weight: weight, should be fetched from the original group norm
    @param bias: bias, should be fetched from the original group norm
    @param eps: epsilon, by default, eps = 1e-6 to match the original group norm

    @return: normalized tensor
    """
    b, c = input.size(0), input.size(1)
    channel_in_group = int(c / num_groups)
    input_reshaped = input.contiguous().view(
        1, int(b * num_groups), channel_in_group, *input.size()[2:])

    # Each (batch, group) pair is treated as one "channel" of batch_norm, which
    # then normalises with the supplied fixed statistics.
    out = F.batch_norm(input_reshaped, mean, var, weight=None, bias=None,
                       training=False, momentum=0, eps=eps)

    out = out.view(b, c, *input.size()[2:])

    # post affine transform (per channel, not per group)
    if weight is not None:
        out *= weight.view(1, -1, 1, 1)
    if bias is not None:
        out += bias.view(1, -1, 1, 1)
    return out


def crop_valid_region(x, input_bbox, target_bbox, is_decoder):
    """
    Crop the valid region from the tile

    @param x: processed tile (output scale)
    @param input_bbox: padded input bounding box ``[x1, x2, y1, y2]`` (input scale)
    @param target_bbox: output bounding box ``[x1, x2, y1, y2]`` (output scale)
    @param is_decoder: output scale is 8x the input for the decoder, 1/8 for the encoder
    @return: cropped tile
    """
    padded_bbox = [i * 8 if is_decoder else i // 8 for i in input_bbox]
    margin = [target_bbox[i] - padded_bbox[i] for i in range(4)]
    return x[:, :, margin[2]:x.size(2) + margin[3], margin[0]:x.size(3) + margin[1]]


# ↓↓↓ https://github.com/Kahsolt/stable-diffusion-webui-vae-tile-infer ↓↓↓

def perfcount(fn):
    """Decorator: time ``fn`` and print peak CUDA memory allocation.

    Garbage collection runs before and after the call.
    """
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        ts = time()

        if torch.cuda.is_available():
            torch.cuda.reset_peak_memory_stats(devices.device)
        devices.torch_gc()
        gc.collect()

        ret = fn(*args, **kwargs)

        devices.torch_gc()
        gc.collect()
        if torch.cuda.is_available():
            vram = torch.cuda.max_memory_allocated(devices.device) / 2**20
            torch.cuda.reset_peak_memory_stats(devices.device)
            print(
                f'[Tiled VAE]: Done in {time() - ts:.3f}s, max VRAM alloc {vram:.3f} MB')
        else:
            print(f'[Tiled VAE]: Done in {time() - ts:.3f}s')

        return ret
    return wrapper

# copy end :)


class GroupNormParam:
    """Accumulates per-tile group-norm statistics (32 groups) so that a
    single set of statistics can be applied to every tile."""

    def __init__(self):
        self.var_list = []
        self.mean_list = []
        self.pixel_list = []
        self.weight = None
        self.bias = None

    def add_tile(self, tile, layer):
        """Record the statistics of ``tile`` and the affine params of ``layer``."""
        var, mean = get_var_mean(tile, 32)
        # For giant images, the variance can be larger than max float16
        # In this case we create a copy to float32
        if var.dtype == torch.float16 and var.isinf().any():
            fp32_tile = tile.float()
            var, mean = get_var_mean(fp32_tile, 32)
        self.var_list.append(var)
        self.mean_list.append(mean)
        self.pixel_list.append(
            tile.shape[2] * tile.shape[3])
        if hasattr(layer, 'weight'):
            self.weight = layer.weight
            self.bias = layer.bias
        else:
            self.weight = None
            self.bias = None

    def summary(self):
        """
        summarize the mean and var and return a function that apply group norm on each tile

        Tiles are weighted by pixel count. The pooled variance follows the
        law of total variance (mean of the per-tile variances plus the
        variance of the per-tile means), so tiles with different means do
        not make the global variance come out too small.
        Returns None when no tile has been added.
        """
        if len(self.var_list) == 0:
            return None
        var = torch.vstack(self.var_list)
        mean = torch.vstack(self.mean_list)
        max_value = max(self.pixel_list)
        pixels = torch.tensor(
            self.pixel_list, dtype=torch.float32, device=devices.device) / max_value
        sum_pixels = torch.sum(pixels)
        pixels = pixels.unsqueeze(1) / sum_pixels  # (num_tiles, 1) weights summing to 1
        pooled_mean = torch.sum(mean * pixels, dim=0)
        pooled_var = torch.sum((var + (mean - pooled_mean) ** 2) * pixels, dim=0)
        return lambda x: custom_group_norm(x, 32, pooled_mean, pooled_var, self.weight, self.bias)

    @staticmethod
    def from_tile(tile, norm):
        """
        create a function from a single tile without summary
        """
        var, mean = get_var_mean(tile, 32)
        if var.dtype == torch.float16 and var.isinf().any():
            fp32_tile = tile.float()
            var, mean = get_var_mean(fp32_tile, 32)
            # if it is a macbook, we need to convert back to float16
            if var.device.type == 'mps':
                # clamp to avoid overflow
                var = torch.clamp(var, 0, 60000)
                var = var.half()
                mean = mean.half()
        if hasattr(norm, 'weight'):
            weight = norm.weight
            bias = norm.bias
        else:
            weight = None
            bias = None

        def group_norm_func(x, mean=mean, var=var, weight=weight, bias=bias):
            return custom_group_norm(x, 32, mean, var, weight, bias, 1e-6)
        return group_norm_func


class VAEHook:
    """Callable that replaces a VAE encoder/decoder forward with tiled execution."""

    def __init__(self, net, tile_size, is_decoder, fast_decoder, fast_encoder, color_fix, to_gpu=False):
        self.net = net  # encoder | decoder
        self.tile_size = tile_size
        self.is_decoder = is_decoder
        self.fast_mode = (fast_encoder and not is_decoder) or (
            fast_decoder and is_decoder)
        self.color_fix = color_fix and not is_decoder
        self.to_gpu = to_gpu
        # Padding in input units: latent pixels (decoder) or image pixels (encoder).
        self.pad = 11 if is_decoder else 32

    def __call__(self, x):
        """Run ``net`` on ``x``, tiling only when the input is large enough.

        Small inputs go straight to ``net.original_forward`` (presumably the
        un-hooked forward saved by whoever installed this hook). The network
        is always moved back to its original device afterwards.
        """
        B, C, H, W = x.shape
        original_device = next(self.net.parameters()).device
        try:
            if self.to_gpu:
                self.net.to(devices.get_optimal_device())
            if max(H, W) <= self.pad * 2 + self.tile_size:
                print("[Tiled VAE]: the input size is tiny and unnecessary to tile.")
                return self.net.original_forward(x)
            else:
                return self.vae_tile_forward(x)
        finally:
            self.net.to(original_device)

    def get_best_tile_size(self, lowerbound, upperbound):
        """
        Get the best tile size for GPU memory

        Rounds ``lowerbound`` up to a multiple of 32, 16, 8, 4 or 2 (largest
        divider first) as long as the result stays <= ``upperbound``; falls
        back to ``lowerbound`` itself.
        """
        divider = 32
        while divider >= 2:
            remainer = lowerbound % divider
            if remainer == 0:
                return lowerbound
            candidate = lowerbound - remainer + divider
            if candidate <= upperbound:
                return candidate
            divider //= 2
        return lowerbound

    def split_tiles(self, h, w):
        """
        Tool function to split the image into tiles
        @param h: height of the image
        @param w: width of the image
        @return: tile_input_bboxes, tile_output_bboxes
                 (each bbox is ``[x1, x2, y1, y2]``; inputs in input scale,
                 outputs in output scale)
        """
        tile_input_bboxes, tile_output_bboxes = [], []
        tile_size = self.tile_size
        pad = self.pad
        num_height_tiles = math.ceil((h - 2 * pad) / tile_size)
        num_width_tiles = math.ceil((w - 2 * pad) / tile_size)
        # If any of the numbers are 0, we let it be 1
        # This is to deal with long and thin images
        num_height_tiles = max(num_height_tiles, 1)
        num_width_tiles = max(num_width_tiles, 1)

        # Suggestions from https://github.com/Kahsolt: auto shrink the tile size
        real_tile_height = math.ceil((h - 2 * pad) / num_height_tiles)
        real_tile_width = math.ceil((w - 2 * pad) / num_width_tiles)
        real_tile_height = self.get_best_tile_size(real_tile_height, tile_size)
        real_tile_width = self.get_best_tile_size(real_tile_width, tile_size)

        print(f'[Tiled VAE]: split to {num_height_tiles}x{num_width_tiles} = {num_height_tiles*num_width_tiles} tiles. ' +
              f'Optimal tile size {real_tile_width}x{real_tile_height}, original tile size {tile_size}x{tile_size}')

        for i in range(num_height_tiles):
            for j in range(num_width_tiles):
                # bbox: [x1, x2, y1, y2]
                # the padding is unnecessary for image borders. So we directly start from (pad, pad)
                input_bbox = [
                    pad + j * real_tile_width,
                    min(pad + (j + 1) * real_tile_width, w),
                    pad + i * real_tile_height,
                    min(pad + (i + 1) * real_tile_height, h),
                ]

                # if the output bbox is close to the image boundary, we extend it to the image boundary
                output_bbox = [
                    input_bbox[0] if input_bbox[0] > pad else 0,
                    input_bbox[1] if input_bbox[1] < w - pad else w,
                    input_bbox[2] if input_bbox[2] > pad else 0,
                    input_bbox[3] if input_bbox[3] < h - pad else h,
                ]

                # scale to get the final output bbox
                output_bbox = [x * 8 if self.is_decoder else x // 8 for x in output_bbox]
                tile_output_bboxes.append(output_bbox)

                # indistinguishable expand the input bbox by pad pixels
                tile_input_bboxes.append([
                    max(0, input_bbox[0] - pad),
                    min(w, input_bbox[1] + pad),
                    max(0, input_bbox[2] - pad),
                    min(h, input_bbox[3] + pad),
                ])

        return tile_input_bboxes, tile_output_bboxes

    @torch.no_grad()
    def estimate_group_norm(self, z, task_queue, color_fix):
        """Run ``task_queue`` on the downsampled input ``z`` and freeze every
        group norm into an ``apply_norm`` task with the statistics seen here.

        Mutates ``task_queue`` in place. Returns True on success, False if
        ``devices.test_for_nans`` rejects an intermediate result (the caller
        then keeps the un-estimated queue).

        Raises:
            ValueError: the queue contains no group norm.
        """
        device = z.device
        tile = z
        last_id = len(task_queue) - 1
        while last_id >= 0 and task_queue[last_id][0] != 'pre_norm':
            last_id -= 1
        if last_id <= 0 or task_queue[last_id][0] != 'pre_norm':
            raise ValueError('No group norm found in the task queue')
        # estimate until the last group norm
        for i in range(last_id + 1):
            task = task_queue[i]
            if task[0] == 'pre_norm':
                group_norm_func = GroupNormParam.from_tile(tile, task[1])
                task_queue[i] = ('apply_norm', group_norm_func)
                if i == last_id:
                    return True
                tile = group_norm_func(tile)
            elif task[0] == 'store_res':
                task_id = i + 1
                while task_id < last_id and task_queue[task_id][0] != 'add_res':
                    task_id += 1
                if task_id >= last_id:
                    continue
                task_queue[task_id][1] = task[1](tile)
            elif task[0] == 'add_res':
                tile += task[1].to(device)
                task[1] = None
            elif color_fix and task[0] == 'downsample':
                # Keep all later residuals on the CPU and stop estimating here.
                for j in range(i, last_id + 1):
                    if task_queue[j][0] == 'store_res':
                        task_queue[j] = ('store_res_cpu', task_queue[j][1])
                return True
            else:
                tile = task[1](tile)
            try:
                devices.test_for_nans(tile, "vae")
            except Exception:
                print('Nan detected in fast mode estimation. Fast mode disabled.')
                return False

        raise IndexError('Should not reach here')

    @perfcount
    @torch.no_grad()
    def vae_tile_forward(self, z):
        """
        Decode a latent vector z into an image in a tiled manner
        (or encode an image into latents when this hook wraps the encoder).
        @param z: latent vector (or image for the encoder)
        @return: image (or latent), in the dtype of ``z``
        """
        device = next(self.net.parameters()).device
        dtype = z.dtype
        net = self.net
        tile_size = self.tile_size
        is_decoder = self.is_decoder

        z = z.detach()  # detach the input to avoid backprop

        N, height, width = z.shape[0], z.shape[2], z.shape[3]
        net.last_z_shape = z.shape

        # Split the input into tiles and build a task queue for each tile
        print(f'[Tiled VAE]: input_size: {z.shape}, tile_size: {tile_size}, padding: {self.pad}')

        in_bboxes, out_bboxes = self.split_tiles(height, width)

        # Prepare tiles by splitting the input
        tiles = [
            z[:, :, bbox[2]:bbox[3], bbox[0]:bbox[1]].cpu()
            for bbox in in_bboxes
        ]

        num_tiles = len(tiles)
        num_completed = 0

        # Build task queues
        single_task_queue = build_task_queue(net, is_decoder)
        if self.fast_mode:
            # Fast mode: downsample the input image to the tile size,
            # then estimate the group norm parameters on the downsampled image
            scale_factor = tile_size / max(height, width)
            z = z.to(device)
            # use nearest-exact to keep statistics as close as possible
            downsampled_z = F.interpolate(z, scale_factor=scale_factor, mode='nearest-exact')
            print(f'[Tiled VAE]: Fast mode enabled, estimating group norm parameters on {downsampled_z.shape[3]} x {downsampled_z.shape[2]} image')

            # ======= Special thanks to @Kahsolt for distribution shift issue ======= #
            # The downsampling will heavily distort its mean and std, so we need to recover it.
            std_old, mean_old = torch.std_mean(z, dim=[0, 2, 3], keepdim=True)
            std_new, mean_new = torch.std_mean(downsampled_z, dim=[0, 2, 3], keepdim=True)
            downsampled_z = (downsampled_z - mean_new) / std_new * std_old + mean_old
            del std_old, mean_old, std_new, mean_new
            # occasionally the std_new is too small or too large, which exceeds the range of float16
            # so we need to clamp it to max z's range.
            downsampled_z = torch.clamp_(downsampled_z, min=z.min(), max=z.max())
            estimate_task_queue = clone_task_queue(single_task_queue)
            if self.estimate_group_norm(downsampled_z, estimate_task_queue, color_fix=self.color_fix):
                single_task_queue = estimate_task_queue
            del downsampled_z

        task_queues = [clone_task_queue(single_task_queue) for _ in range(num_tiles)]

        # Dummy result
        result = None
        # Fallback output if no tile completes; nothing computes it at present.
        result_approx = None

        # Free memory of input latent tensor
        del z

        # Task queue execution
        pbar = tqdm(total=num_tiles * len(task_queues[0]), desc=f"[Tiled VAE]: Executing {'Decoder' if is_decoder else 'Encoder'} Task Queue: ")

        # execute the task back and forth when switch tiles so that we always
        # keep one tile on the GPU to reduce unnecessary data transfer
        forward = True
        # Cancellation hook: no code path currently sets this to True.
        interrupted = False
        while True:
            group_norm_param = GroupNormParam()
            for i in range(num_tiles) if forward else reversed(range(num_tiles)):
                tile = tiles[i].to(device)
                task_queue = task_queues[i]

                interrupted = False
                # Run this tile's tasks until it hits a group norm (which needs
                # statistics from every tile) or finishes.
                while len(task_queue) > 0:
                    task = task_queue.pop(0)
                    if task[0] == 'pre_norm':
                        group_norm_param.add_tile(tile, task[1])
                        break
                    elif task[0] == 'store_res' or task[0] == 'store_res_cpu':
                        task_id = 0
                        res = task[1](tile)
                        if not self.fast_mode or task[0] == 'store_res_cpu':
                            res = res.cpu()
                        while task_queue[task_id][0] != 'add_res':
                            task_id += 1
                        task_queue[task_id][1] = res
                    elif task[0] == 'add_res':
                        tile += task[1].to(device)
                        task[1] = None
                    else:
                        tile = task[1](tile)
                    pbar.update(1)

                if interrupted:
                    break

                if len(task_queue) == 0:
                    tiles[i] = None
                    num_completed += 1
                    if result is None:
                        # NOTE: dim C varies from different cases, can only be inited dynamically
                        result = torch.zeros((N, tile.shape[1], height * 8 if is_decoder else height // 8, width * 8 if is_decoder else width // 8), device=device, requires_grad=False)
                    result[:, :, out_bboxes[i][2]:out_bboxes[i][3], out_bboxes[i][0]:out_bboxes[i][1]] = crop_valid_region(tile, in_bboxes[i], out_bboxes[i], is_decoder)
                    del tile
                elif i == num_tiles - 1 and forward:
                    # Last tile of a forward sweep: keep it on the GPU, the
                    # next sweep starts from it.
                    forward = False
                    tiles[i] = tile
                elif i == 0 and not forward:
                    forward = True
                    tiles[i] = tile
                else:
                    tiles[i] = tile.cpu()
                    del tile

            if interrupted:
                break
            if num_completed == num_tiles:
                break

            # insert the group norm task to the head of each task queue
            group_norm_func = group_norm_param.summary()
            if group_norm_func is not None:
                for i in range(num_tiles):
                    task_queue = task_queues[i]
                    task_queue.insert(0, ('apply_norm', group_norm_func))

        # Done!
        pbar.close()
        if result is not None:
            return result.to(dtype)
        return result_approx.to(device) if result_approx is not None else None