Flux9665's picture
initial commit
6faeba1
raw
history blame
1.47 kB
import json
import torch
import torch.nn as nn
from Preprocessing.Codec.env import AttrDict
from Preprocessing.Codec.models import Encoder
from Preprocessing.Codec.models import Generator
from Preprocessing.Codec.models import Quantizer
class VQVAE(nn.Module):
    """Bundle a quantizer, a generator and (optionally) an encoder.

    The sub-modules are built from a JSON config file and their weights are
    restored from a single checkpoint file.
    """

    def __init__(self,
                 config_path,
                 ckpt_path,
                 with_encoder=False):
        """Build the sub-modules and load their weights.

        Args:
            config_path: Path to the JSON config; its content is wrapped in
                an ``AttrDict`` and handed to every sub-module constructor.
            ckpt_path: Path to a checkpoint readable by ``torch.load`` that
                holds the keys ``'generator'`` and ``'quantizer'`` (and
                ``'encoder'`` when ``with_encoder`` is true).
            with_encoder: When true, also build and load the encoder, which
                ``encode`` requires.
        """
        super().__init__()
        # NOTE(security): torch.load unpickles the file, which can execute
        # arbitrary code. Only load checkpoints from a trusted source.
        ckpt = torch.load(ckpt_path, map_location=torch.device('cpu'))
        # JSON is UTF-8 by specification; do not depend on the locale default.
        with open(config_path, encoding="utf-8") as f:
            json_config = json.load(f)
        self.h = AttrDict(json_config)
        self.quantizer = Quantizer(self.h)
        self.generator = Generator(self.h)
        self.generator.load_state_dict(ckpt['generator'])
        self.quantizer.load_state_dict(ckpt['quantizer'])
        if with_encoder:
            self.encoder = Encoder(self.h)
            self.encoder.load_state_dict(ckpt['encoder'])

    def forward(self, x):
        """Embed the given codes with the quantizer and run the generator.

        Args:
            x: Codebook indices; per the original author's note the shape is
                (B, T, Nq) -- TODO confirm against ``Quantizer.embed``.

        Returns:
            Whatever ``self.generator`` returns for the embedded codes.
        """
        quant_emb = self.quantizer.embed(x)
        return self.generator(quant_emb)

    def encode(self, x):
        """Encode ``x`` into stacked code indices.

        Args:
            x: Input tensor with the batch on dim 0. A 3-D input whose last
                dim has size 1 is squeezed first; a channel dim is then
                inserted at position 1 before the encoder is called.

        Returns:
            The code tensors returned by the quantizer, each reshaped to
            (batch, -1) and stacked along a new last dim. The original
            author's note gives the shape as [N, T, 4] -- the last size is
            the number of code tensors the quantizer returns.

        Raises:
            AttributeError: If the model was built without an encoder.
        """
        if getattr(self, "encoder", None) is None:
            # Same exception type as the implicit failure, but with the cause
            # and the fix spelled out.
            raise AttributeError(
                "VQVAE was constructed without an encoder; "
                "pass with_encoder=True to use encode()"
            )
        batch_size = x.size(0)
        if x.dim() == 3 and x.shape[-1] == 1:
            x = x.squeeze(-1)
        latent = self.encoder(x.unsqueeze(1))
        # The quantizer returns (quantized, loss, codes); only codes are used.
        _, _, codes = self.quantizer(latent)
        flat_codes = [code.reshape(batch_size, -1) for code in codes]
        return torch.stack(flat_codes, -1)