Spaces:
Running
on
A10G
Running
on
A10G
import torch | |
import torch.nn.functional as F | |
import numpy as np | |
import math | |
from src import utils | |
from src.egnn import Dynamics | |
from src.noise import GammaNetwork, PredefinedNoiseSchedule | |
from typing import Union | |
from tqdm import tqdm | |
from pdb import set_trace | |
class EDM(torch.nn.Module):
    """Diffusion model over concatenated coordinates and features (``xh``).

    In this class only the entries selected by ``linker_mask`` are noised and
    denoised; entries selected by ``fragment_mask`` are passed through
    unchanged (see ``forward`` and ``sample_chain``).

    Tensor conventions are inferred from the indexing in this class: ``x`` and
    ``h`` are rank-3 and concatenated on dim 2, masks are squeezed on dim 2 and
    summed over dim 1 -- presumably ``(batch, nodes, channels)`` and
    ``(batch, nodes, 1)``; confirm against the callers.
    """

    def __init__(
        self,
        dynamics: "Dynamics",
        in_node_nf: int,
        n_dims: int,
        timesteps: int = 1000,
        noise_schedule='learned',
        noise_precision=1e-4,
        loss_type='vlb',
        norm_values=(1., 1., 1.),
        norm_biases=(None, 0., 0.),
    ):
        """Store the denoising network and build the noise schedule.

        Args:
            dynamics: Network whose ``forward`` predicts the noise ``eps``.
            in_node_nf: Number of feature channels in ``h``.
            n_dims: Number of coordinate channels in ``x``.
            timesteps: Number of diffusion steps ``T``.
            noise_schedule: ``'learned'`` selects ``GammaNetwork``; any other
                value is forwarded to ``PredefinedNoiseSchedule``.
            noise_precision: Forwarded to ``PredefinedNoiseSchedule``.
            loss_type: Must be ``'vlb'`` when the schedule is learned.
            norm_values: Scales; index 0 is applied to ``x``, index 1 to ``h``.
            norm_biases: Biases; index 1 is applied to ``h``.
        """
        super().__init__()
        if noise_schedule == 'learned':
            assert loss_type == 'vlb', 'A noise schedule can only be learned with a vlb objective'
            self.gamma = GammaNetwork()
        else:
            self.gamma = PredefinedNoiseSchedule(noise_schedule, timesteps=timesteps, precision=noise_precision)

        self.dynamics = dynamics
        self.in_node_nf = in_node_nf
        self.n_dims = n_dims
        self.T = timesteps
        self.norm_values = norm_values
        self.norm_biases = norm_biases

    def forward(self, x, h, node_mask, fragment_mask, linker_mask, edge_mask, context=None):
        """Compute the training loss terms for one batch.

        A random integer timestep in ``[0, T]`` is drawn per sample, the linker
        part of ``xh`` is noised accordingly and the network's noise prediction
        is compared with the true noise.

        Returns:
            Tuple ``(delta_log_px, kl_prior, loss_term_t, loss_term_0,
            l2_loss, noise_t, noise_0)``. ``loss_term_0`` and ``noise_0`` are
            the float ``0.`` when no sample in the batch drew ``t == 0``.
        """
        # Normalization and concatenation
        x, h = self.normalize(x, h)
        xh = torch.cat([x, h], dim=2)

        # Volume change loss term
        delta_log_px = self.delta_log_px(linker_mask).mean()

        # Sample t (and its predecessor s = t - 1) and rescale both by T
        t_int = torch.randint(0, self.T + 1, size=(x.size(0), 1), device=x.device).float()
        s_int = t_int - 1
        t = t_int / self.T
        s = s_int / self.T

        # Masks for t=0 and t>0. squeeze(1) keeps the batch axis even when the
        # batch holds a single sample.
        t_is_zero = (t_int == 0).squeeze(1).float()
        t_is_not_zero = 1 - t_is_zero
        # Clamp so that a batch consisting only of t == 0 samples yields 0
        # instead of 0 / 0 = NaN for the t > 0 averages.
        n_not_zero = t_is_not_zero.sum().clamp(min=1)

        # Compute gamma_t and gamma_s according to the noise schedule
        gamma_t = self.inflate_batch_array(self.gamma(t), x)
        gamma_s = self.inflate_batch_array(self.gamma(s), x)

        # Compute alpha_t and sigma_t from gamma
        alpha_t = self.alpha(gamma_t, x)
        sigma_t = self.sigma(gamma_t, x)

        # Sample noise (masked by linker_mask only)
        eps_t = self.sample_combined_position_feature_noise(n_samples=x.size(0), n_nodes=x.size(1), mask=linker_mask)

        # Sample z_t given x, h for timestep t, from q(z_t | x, h);
        # fragment entries keep their clean values.
        z_t = alpha_t * xh + sigma_t * eps_t
        z_t = xh * fragment_mask + z_t * linker_mask

        # Neural net prediction
        eps_t_hat = self.dynamics.forward(
            xh=z_t,
            t=t,
            node_mask=node_mask,
            linker_mask=linker_mask,
            context=context,
            edge_mask=edge_mask,
        )
        eps_t_hat = eps_t_hat * linker_mask

        # Computing basic error (further used for computing NLL and L2-loss)
        error_t = self.sum_except_batch((eps_t - eps_t_hat) ** 2)

        # Computing L2-loss, normalised by the number of noised entries
        normalization = (self.n_dims + self.in_node_nf) * self.numbers_of_nodes(linker_mask)
        l2_loss = error_t / normalization
        l2_loss = l2_loss.mean()

        # The KL between q(z_T | x) and p(z_T) = Normal(0, 1) (should be close to zero)
        kl_prior = self.kl_prior(xh, linker_mask).mean()

        # Computing NLL middle term
        SNR_weight = (self.SNR(gamma_s - gamma_t) - 1).squeeze(1).squeeze(1)
        loss_term_t = self.T * 0.5 * SNR_weight * error_t
        loss_term_t = (loss_term_t * t_is_not_zero).sum() / n_not_zero

        # Norm of the noise returned by dynamics, averaged over t > 0 samples
        noise = torch.norm(eps_t_hat, dim=[1, 2])
        noise_t = (noise * t_is_not_zero).sum() / n_not_zero

        n_zero = t_is_zero.sum()
        if n_zero > 0:
            # The _constants_ depending on sigma_0 from the
            # cross entropy term E_q(z0 | x) [log p(x | z0)]
            neg_log_constants = -self.log_constant_of_p_x_given_z0(x, linker_mask)

            # Computes the L_0 term (even if gamma_t is not actually gamma_0)
            # and selects only the relevant samples via masking
            loss_term_0 = -self.log_p_xh_given_z0_without_constants(h, z_t, gamma_t, eps_t, eps_t_hat, linker_mask)
            loss_term_0 = loss_term_0 + neg_log_constants
            loss_term_0 = (loss_term_0 * t_is_zero).sum() / n_zero

            # Norm of the noise returned by dynamics, averaged over t == 0 samples
            noise_0 = (noise * t_is_zero).sum() / n_zero
        else:
            loss_term_0 = 0.
            noise_0 = 0.

        return delta_log_px, kl_prior, loss_term_t, loss_term_0, l2_loss, noise_t, noise_0

    @torch.no_grad()  # sampling runs T network calls; do not build an autograd graph
    def sample_chain(self, x, h, node_mask, fragment_mask, linker_mask, edge_mask, context, keep_frames=None):
        """Run the reverse process and return the stored intermediate frames.

        Args:
            keep_frames: Number of frames to store; defaults to ``T`` and must
                not exceed it.

        Returns:
            Tensor of shape ``(keep_frames,) + z.size()``. Index 0 holds the
            final sample from ``p(x, h | z_0)``.
        """
        n_samples = x.size(0)
        n_nodes = x.size(1)

        # Normalization and concatenation
        x, h = self.normalize(x, h)
        xh = torch.cat([x, h], dim=2)

        # Initial linker sampling from N(0, I); fragments keep their values
        z = self.sample_combined_position_feature_noise(n_samples, n_nodes, mask=linker_mask)
        z = xh * fragment_mask + z * linker_mask

        if keep_frames is None:
            keep_frames = self.T
        else:
            assert keep_frames <= self.T
        chain = torch.zeros((keep_frames,) + z.size(), device=z.device)

        # Sample p(z_s | z_t) for s = T-1, ..., 0
        for s in tqdm(reversed(range(0, self.T)), total=self.T):
            s_array = torch.full((n_samples, 1), fill_value=s, device=z.device)
            t_array = s_array + 1
            s_array = s_array / self.T
            t_array = t_array / self.T

            z = self.sample_p_zs_given_zt_only_linker(
                s=s_array,
                t=t_array,
                z_t=z,
                node_mask=node_mask,
                fragment_mask=fragment_mask,
                linker_mask=linker_mask,
                edge_mask=edge_mask,
                context=context,
            )

            # Map step s onto one of the keep_frames slots
            write_index = (s * keep_frames) // self.T
            chain[write_index] = self.unnormalize_z(z)

        # Finally sample p(x, h | z_0)
        x, h = self.sample_p_xh_given_z0_only_linker(
            z_0=z,
            node_mask=node_mask,
            fragment_mask=fragment_mask,
            linker_mask=linker_mask,
            edge_mask=edge_mask,
            context=context,
        )
        chain[0] = torch.cat([x, h], dim=2)

        return chain

    def sample_p_zs_given_zt_only_linker(self, s, t, z_t, node_mask, fragment_mask, linker_mask, edge_mask, context):
        """Samples from zs ~ p(zs | zt). Only used during sampling. Samples only linker features and coords"""
        gamma_s = self.gamma(s)
        gamma_t = self.gamma(t)

        sigma2_t_given_s, sigma_t_given_s, alpha_t_given_s = self.sigma_and_alpha_t_given_s(gamma_t, gamma_s, z_t)
        sigma_s = self.sigma(gamma_s, target_tensor=z_t)
        sigma_t = self.sigma(gamma_t, target_tensor=z_t)

        # Neural net prediction.
        eps_hat = self.dynamics.forward(
            xh=z_t,
            t=t,
            node_mask=node_mask,
            linker_mask=linker_mask,
            context=context,
            edge_mask=edge_mask,
        )
        eps_hat = eps_hat * linker_mask

        # Compute mu for p(z_s | z_t)
        mu = z_t / alpha_t_given_s - (sigma2_t_given_s / alpha_t_given_s / sigma_t) * eps_hat

        # Compute sigma for p(z_s | z_t)
        sigma = sigma_t_given_s * sigma_s / sigma_t

        # Sample z_s given the parameters derived from z_t; fragments are
        # copied from z_t unchanged.
        z_s = self.sample_normal(mu, sigma, linker_mask)
        z_s = z_t * fragment_mask + z_s * linker_mask

        return z_s

    def sample_p_xh_given_z0_only_linker(self, z_0, node_mask, fragment_mask, linker_mask, edge_mask, context):
        """Samples x ~ p(x|z0). Samples only linker features and coords"""
        zeros = torch.zeros(size=(z_0.size(0), 1), device=z_0.device)
        gamma_0 = self.gamma(zeros)

        # Computes sqrt(sigma_0^2 / alpha_0^2)
        sigma_x = self.SNR(-0.5 * gamma_0).unsqueeze(1)

        eps_hat = self.dynamics.forward(
            t=zeros,
            xh=z_0,
            node_mask=node_mask,
            linker_mask=linker_mask,
            edge_mask=edge_mask,
            context=context
        )
        eps_hat = eps_hat * linker_mask

        mu_x = self.compute_x_pred(eps_t=eps_hat, z_t=z_0, gamma_t=gamma_0)
        xh = self.sample_normal(mu=mu_x, sigma=sigma_x, node_mask=linker_mask)
        xh = z_0 * fragment_mask + xh * linker_mask

        x, h = xh[:, :, :self.n_dims], xh[:, :, self.n_dims:]
        x, h = self.unnormalize(x, h)
        # Snap features to a one-hot vector at the arg-max channel
        h = F.one_hot(torch.argmax(h, dim=2), self.in_node_nf) * node_mask

        return x, h

    def compute_x_pred(self, eps_t, z_t, gamma_t):
        """Computes x_pred, i.e. the most likely prediction of x."""
        sigma_t = self.sigma(gamma_t, target_tensor=eps_t)
        alpha_t = self.alpha(gamma_t, target_tensor=eps_t)
        x_pred = 1. / alpha_t * (z_t - sigma_t * eps_t)
        return x_pred

    def kl_prior(self, xh, mask):
        """
        Computes the KL between q(z1 | x) and the prior p(z1) = Normal(0, 1).
        This is essentially a lot of work for something that is in practice negligible in the loss.
        However, you compute it so that you see it when you've made a mistake in your noise schedule.
        """
        # Compute the last alpha value, alpha_T
        ones = torch.ones((xh.size(0), 1), device=xh.device)
        gamma_T = self.gamma(ones)
        alpha_T = self.alpha(gamma_T, xh)

        # Compute means
        mu_T = alpha_T * xh
        mu_T_x, mu_T_h = mu_T[:, :, :self.n_dims], mu_T[:, :, self.n_dims:]

        # Compute standard deviations (only batch axis for x-part, inflated for h-part)
        sigma_T_x = self.sigma(gamma_T, mu_T_x).view(-1)  # Remove inflate, only keep batch dimension for x-part
        sigma_T_h = self.sigma(gamma_T, mu_T_h)

        # Compute KL for h-part
        zeros, ones = torch.zeros_like(mu_T_h), torch.ones_like(sigma_T_h)
        kl_distance_h = self.gaussian_kl(mu_T_h, sigma_T_h, zeros, ones)

        # Compute KL for x-part
        zeros, ones = torch.zeros_like(mu_T_x), torch.ones_like(sigma_T_x)
        d = self.dimensionality(mask)
        kl_distance_x = self.gaussian_kl_for_dimension(mu_T_x, sigma_T_x, zeros, ones, d=d)

        return kl_distance_x + kl_distance_h

    def log_constant_of_p_x_given_z0(self, x, mask):
        """Return the per-sample constant of ``log p(x | z_0)`` that depends on sigma_0."""
        batch_size = x.size(0)
        degrees_of_freedom_x = self.dimensionality(mask)
        zeros = torch.zeros((batch_size, 1), device=x.device)
        gamma_0 = self.gamma(zeros)

        # Recall that sigma_x = sqrt(sigma_0^2 / alpha_0^2) = SNR(-0.5 gamma_0)
        log_sigma_x = 0.5 * gamma_0.view(batch_size)

        return degrees_of_freedom_x * (- log_sigma_x - 0.5 * np.log(2 * np.pi))

    def log_p_xh_given_z0_without_constants(self, h, z_0, gamma_0, eps, eps_hat, mask, epsilon=1e-10):
        """Return ``log p(x, h | z_0)`` per sample, without the sigma_0 constants.

        Args:
            epsilon: Small value added inside the log for numerical stability.
        """
        # Discrete properties are predicted directly from z_0
        z_h = z_0[:, :, self.n_dims:]

        # Take only part over x
        eps_x = eps[:, :, :self.n_dims]
        eps_hat_x = eps_hat[:, :, :self.n_dims]

        # Compute sigma_0 and rescale to the integer scale of the data
        sigma_0 = self.sigma(gamma_0, target_tensor=z_0) * self.norm_values[1]

        # Computes the error for the distribution N(x | 1 / alpha_0 z_0 + sigma_0/alpha_0 eps_0, sigma_0 / alpha_0),
        # the weighting in the epsilon parametrization is exactly '1'
        log_p_x_given_z_without_constants = -0.5 * self.sum_except_batch((eps_x - eps_hat_x) ** 2)

        # Categorical features: undo the normalization of h and of z_h
        h = h * self.norm_values[1] + self.norm_biases[1]
        estimated_h = z_h * self.norm_values[1] + self.norm_biases[1]

        # Centered h_cat around 1, since onehot encoded
        centered_h = estimated_h - 1

        # Compute integrals from 0.5 to 1.5 of the normal distribution
        # N(mean=centered_h_cat, stdev=sigma_0_cat)
        log_p_h_proportional = torch.log(
            self.cdf_standard_gaussian((centered_h + 0.5) / sigma_0) -
            self.cdf_standard_gaussian((centered_h - 0.5) / sigma_0) +
            epsilon
        )

        # Normalize the distribution over the categories
        log_Z = torch.logsumexp(log_p_h_proportional, dim=2, keepdim=True)
        log_probabilities = log_p_h_proportional - log_Z

        # Select the log_prob of the current category using the onehot representation
        log_p_h_given_z = self.sum_except_batch(log_probabilities * h * mask)

        # Combine log probabilities for x and h
        log_p_xh_given_z = log_p_x_given_z_without_constants + log_p_h_given_z

        return log_p_xh_given_z

    def sample_combined_position_feature_noise(self, n_samples, n_nodes, mask):
        """Draw masked Gaussian noise for positions and features, concatenated on dim 2."""
        z_x = utils.sample_gaussian_with_mask(
            size=(n_samples, n_nodes, self.n_dims),
            device=mask.device,
            node_mask=mask
        )
        z_h = utils.sample_gaussian_with_mask(
            size=(n_samples, n_nodes, self.in_node_nf),
            device=mask.device,
            node_mask=mask
        )
        z = torch.cat([z_x, z_h], dim=2)
        return z

    def sample_normal(self, mu, sigma, node_mask):
        """Samples from a Normal distribution."""
        eps = self.sample_combined_position_feature_noise(mu.size(0), mu.size(1), node_mask)
        return mu + sigma * eps

    def normalize(self, x, h):
        """Scale ``x`` by ``norm_values[0]`` and shift/scale ``h`` by index 1 of the biases/values."""
        new_x = x / self.norm_values[0]
        new_h = (h.float() - self.norm_biases[1]) / self.norm_values[1]
        return new_x, new_h

    def unnormalize(self, x, h):
        """Inverse of ``normalize``."""
        new_x = x * self.norm_values[0]
        new_h = h * self.norm_values[1] + self.norm_biases[1]
        return new_x, new_h

    def unnormalize_z(self, z):
        """Split ``z`` into its x and h parts, unnormalize both and re-concatenate."""
        assert z.size(2) == self.n_dims + self.in_node_nf
        x, h = z[:, :, :self.n_dims], z[:, :, self.n_dims:]
        x, h = self.unnormalize(x, h)
        return torch.cat([x, h], dim=2)

    def delta_log_px(self, mask):
        """Log-volume change caused by rescaling ``x`` with ``norm_values[0]``."""
        return -self.dimensionality(mask) * np.log(self.norm_values[0])

    def dimensionality(self, mask):
        """Number of coordinate degrees of freedom per sample: masked nodes times ``n_dims``."""
        return self.numbers_of_nodes(mask) * self.n_dims

    def sigma(self, gamma, target_tensor):
        """Computes sigma given gamma."""
        return self.inflate_batch_array(torch.sqrt(torch.sigmoid(gamma)), target_tensor)

    def alpha(self, gamma, target_tensor):
        """Computes alpha given gamma."""
        return self.inflate_batch_array(torch.sqrt(torch.sigmoid(-gamma)), target_tensor)

    def SNR(self, gamma):
        """Computes signal to noise ratio (alpha^2/sigma^2) given gamma."""
        return torch.exp(-gamma)

    def sigma_and_alpha_t_given_s(self, gamma_t: torch.Tensor, gamma_s: torch.Tensor, target_tensor: torch.Tensor):
        """
        Computes sigma t given s, using gamma_t and gamma_s. Used during sampling.
        These are defined as:
            alpha t given s = alpha t / alpha s,
            sigma t given s = sqrt(1 - (alpha t given s) ^2 ).

        Returns:
            Tuple ``(sigma2_t_given_s, sigma_t_given_s, alpha_t_given_s)``,
            each inflated to the rank of ``target_tensor``.
        """
        sigma2_t_given_s = self.inflate_batch_array(
            -self.expm1(self.softplus(gamma_s) - self.softplus(gamma_t)),
            target_tensor
        )

        # alpha_t_given_s = alpha_t / alpha_s, computed in log space
        log_alpha2_t = F.logsigmoid(-gamma_t)
        log_alpha2_s = F.logsigmoid(-gamma_s)
        log_alpha2_t_given_s = log_alpha2_t - log_alpha2_s

        alpha_t_given_s = torch.exp(0.5 * log_alpha2_t_given_s)
        alpha_t_given_s = self.inflate_batch_array(alpha_t_given_s, target_tensor)
        sigma_t_given_s = torch.sqrt(sigma2_t_given_s)

        return sigma2_t_given_s, sigma_t_given_s, alpha_t_given_s

    # The helpers below take no ``self`` but are invoked as ``self.helper(...)``
    # throughout the class, so they must be static methods.

    @staticmethod
    def numbers_of_nodes(mask):
        """Sum the mask over dim 1 after dropping its trailing singleton dim 2."""
        return torch.sum(mask.squeeze(2), dim=1)

    @staticmethod
    def inflate_batch_array(array, target):
        """
        Inflates the batch array (array) with only a single axis (i.e. shape = (batch_size,),
        or possibly more empty axes (i.e. shape (batch_size, 1, ..., 1)) to match the target shape.
        """
        target_shape = (array.size(0),) + (1,) * (len(target.size()) - 1)
        return array.view(target_shape)

    @staticmethod
    def sum_except_batch(x):
        """Sum over every dimension except the first."""
        return x.view(x.size(0), -1).sum(-1)

    @staticmethod
    def expm1(x: torch.Tensor) -> torch.Tensor:
        """Thin wrapper around ``torch.expm1``."""
        return torch.expm1(x)

    @staticmethod
    def softplus(x: torch.Tensor) -> torch.Tensor:
        """Thin wrapper around ``torch.nn.functional.softplus``."""
        return F.softplus(x)

    @staticmethod
    def cdf_standard_gaussian(x):
        """Cumulative distribution function of the standard normal."""
        return 0.5 * (1. + torch.erf(x / math.sqrt(2)))

    @staticmethod
    def gaussian_kl(q_mu, q_sigma, p_mu, p_sigma):
        """
        Computes the KL distance between two normal distributions.
        Args:
            q_mu: Mean of distribution q.
            q_sigma: Standard deviation of distribution q.
            p_mu: Mean of distribution p.
            p_sigma: Standard deviation of distribution p.
        Returns:
            The KL distance, summed over all dimensions except the batch dim.
        """
        kl = torch.log(p_sigma / q_sigma) + 0.5 * (q_sigma ** 2 + (q_mu - p_mu) ** 2) / (p_sigma ** 2) - 0.5
        return EDM.sum_except_batch(kl)

    @staticmethod
    def gaussian_kl_for_dimension(q_mu, q_sigma, p_mu, p_sigma, d):
        """
        Computes the KL distance between two normal distributions taking the dimension into account.
        Args:
            q_mu: Mean of distribution q.
            q_sigma: Standard deviation of distribution q.
            p_mu: Mean of distribution p.
            p_sigma: Standard deviation of distribution p.
            d: dimension
        Returns:
            The KL distance, summed over all dimensions except the batch dim.
        """
        mu_norm_2 = EDM.sum_except_batch((q_mu - p_mu) ** 2)
        return d * torch.log(p_sigma / q_sigma) + 0.5 * (d * q_sigma ** 2 + mu_norm_2) / (p_sigma ** 2) - 0.5 * d
class InpaintingEDM(EDM):
    """EDM variant that is trained on all nodes and inpaints at sampling time.

    During training every node selected by ``node_mask`` is noised. During
    sampling (``sample_chain``) linker entries are taken from the learned
    reverse process ``p(z_s | z_t)`` while fragment entries are taken from
    ``q(z_s | z_t, x)`` conditioned on the known fragment data.
    """

    def forward(self, x, h, node_mask, fragment_mask, linker_mask, edge_mask, context=None):
        """Compute the training loss terms for one batch, noising all nodes.

        ``fragment_mask`` and ``linker_mask`` are accepted for signature
        compatibility with ``EDM.forward`` but are not used here.

        Returns:
            Tuple ``(delta_log_px, kl_prior, loss_term_t, loss_term_0,
            l2_loss, noise_t, noise_0)``. ``loss_term_0`` and ``noise_0`` are
            the float ``0.`` when no sample in the batch drew ``t == 0``.
        """
        # Normalization and concatenation
        x, h = self.normalize(x, h)
        xh = torch.cat([x, h], dim=2)

        # Volume change loss term
        delta_log_px = self.delta_log_px(node_mask).mean()

        # Sample t (and its predecessor s = t - 1) and rescale both by T
        t_int = torch.randint(0, self.T + 1, size=(x.size(0), 1), device=x.device).float()
        s_int = t_int - 1
        t = t_int / self.T
        s = s_int / self.T

        # Masks for t=0 and t>0. squeeze(1) keeps the batch axis even when the
        # batch holds a single sample.
        t_is_zero = (t_int == 0).squeeze(1).float()
        t_is_not_zero = 1 - t_is_zero
        # Clamp so that a batch consisting only of t == 0 samples yields 0
        # instead of 0 / 0 = NaN for the t > 0 averages.
        n_not_zero = t_is_not_zero.sum().clamp(min=1)

        # Compute gamma_t and gamma_s according to the noise schedule
        gamma_t = self.inflate_batch_array(self.gamma(t), x)
        gamma_s = self.inflate_batch_array(self.gamma(s), x)

        # Compute alpha_t and sigma_t from gamma
        alpha_t = self.alpha(gamma_t, x)
        sigma_t = self.sigma(gamma_t, x)

        # Sample noise for every node in node_mask
        eps_t = self.sample_combined_position_feature_noise(n_samples=x.size(0), n_nodes=x.size(1), mask=node_mask)

        # Sample z_t given x, h for timestep t, from q(z_t | x, h).
        # Unlike EDM.forward, fragments are noised as well.
        z_t = alpha_t * xh + sigma_t * eps_t

        # Neural net prediction
        eps_t_hat = self.dynamics.forward(
            xh=z_t,
            t=t,
            node_mask=node_mask,
            linker_mask=None,
            context=context,
            edge_mask=edge_mask,
        )

        # Computing basic error (further used for computing NLL and L2-loss)
        error_t = self.sum_except_batch((eps_t - eps_t_hat) ** 2)

        # Computing L2-loss, normalised by the number of noised entries
        normalization = (self.n_dims + self.in_node_nf) * self.numbers_of_nodes(node_mask)
        l2_loss = error_t / normalization
        l2_loss = l2_loss.mean()

        # The KL between q(z_T | x) and p(z_T) = Normal(0, 1) (should be close to zero)
        kl_prior = self.kl_prior(xh, node_mask).mean()

        # Computing NLL middle term
        SNR_weight = (self.SNR(gamma_s - gamma_t) - 1).squeeze(1).squeeze(1)
        loss_term_t = self.T * 0.5 * SNR_weight * error_t
        loss_term_t = (loss_term_t * t_is_not_zero).sum() / n_not_zero

        # Norm of the noise returned by dynamics, averaged over t > 0 samples
        noise = torch.norm(eps_t_hat, dim=[1, 2])
        noise_t = (noise * t_is_not_zero).sum() / n_not_zero

        n_zero = t_is_zero.sum()
        if n_zero > 0:
            # The _constants_ depending on sigma_0 from the
            # cross entropy term E_q(z0 | x) [log p(x | z0)]
            neg_log_constants = -self.log_constant_of_p_x_given_z0(x, node_mask)

            # Computes the L_0 term (even if gamma_t is not actually gamma_0)
            # and selects only the relevant samples via masking
            loss_term_0 = -self.log_p_xh_given_z0_without_constants(h, z_t, gamma_t, eps_t, eps_t_hat, node_mask)
            loss_term_0 = loss_term_0 + neg_log_constants
            loss_term_0 = (loss_term_0 * t_is_zero).sum() / n_zero

            # Norm of the noise returned by dynamics, averaged over t == 0 samples
            noise_0 = (noise * t_is_zero).sum() / n_zero
        else:
            loss_term_0 = 0.
            noise_0 = 0.

        return delta_log_px, kl_prior, loss_term_t, loss_term_0, l2_loss, noise_t, noise_0

    @torch.no_grad()  # sampling runs T network calls; do not build an autograd graph
    def sample_chain(self, x, h, node_mask, edge_mask, fragment_mask, linker_mask, context, keep_frames=None):
        """Run the inpainting reverse process and return the stored frames.

        NOTE(review): the positional order here is ``node_mask, edge_mask,
        fragment_mask, linker_mask`` whereas ``EDM.sample_chain`` takes
        ``node_mask, fragment_mask, linker_mask, edge_mask``. Call with
        keyword arguments to avoid silently swapping masks.

        Args:
            keep_frames: Number of frames to store; defaults to ``T`` and must
                not exceed it.

        Returns:
            Tensor of shape ``(keep_frames,) + z.size()``. Index 0 holds the
            final combined sample of ``x`` and ``h``.
        """
        n_samples = x.size(0)
        n_nodes = x.size(1)

        # Normalization and concatenation
        x, h = self.normalize(x, h)
        xh = torch.cat([x, h], dim=2)

        # Sampling initial noise for all nodes
        z = self.sample_combined_position_feature_noise(n_samples, n_nodes, node_mask)

        if keep_frames is None:
            keep_frames = self.T
        else:
            assert keep_frames <= self.T
        chain = torch.zeros((keep_frames,) + z.size(), device=z.device)

        # Sample p(z_s | z_t) for s = T-1, ..., 0
        for s in tqdm(reversed(range(0, self.T)), total=self.T):
            s_array = torch.full((n_samples, 1), fill_value=s, device=z.device)
            t_array = s_array + 1
            s_array = s_array / self.T
            t_array = t_array / self.T

            # Learned reverse step; only its linker entries are kept below
            z_linker_only_sampled = self.sample_p_zs_given_zt(
                s=s_array,
                t=t_array,
                z_t=z,
                node_mask=node_mask,
                edge_mask=edge_mask,
                context=context,
            )
            # Posterior step conditioned on the known fragments; only its
            # fragment entries are kept below
            z_fragments_only_sampled = self.sample_q_zs_given_zt_and_x(
                s=s_array,
                t=t_array,
                z_t=z,
                x=xh * fragment_mask,
                node_mask=fragment_mask,
            )
            z = z_linker_only_sampled * linker_mask + z_fragments_only_sampled * fragment_mask

            # Project down to avoid numerical runaway of the center of gravity
            z_x = utils.remove_mean_with_mask(z[:, :, :self.n_dims], node_mask)
            z_h = z[:, :, self.n_dims:]
            z = torch.cat([z_x, z_h], dim=2)

            # Saving step to the chain
            write_index = (s * keep_frames) // self.T
            chain[write_index] = self.unnormalize_z(z)

        # Finally sample p(x, h | z_0)
        x_out_linker, h_out_linker = self.sample_p_xh_given_z0(
            z_0=z,
            node_mask=node_mask,
            edge_mask=edge_mask,
            context=context,
        )
        x_out_fragments, h_out_fragments = self.sample_q_xh_given_z0_and_x(z_0=z, node_mask=node_mask)

        xh_out_linker = torch.cat([x_out_linker, h_out_linker], dim=2)
        xh_out_fragments = torch.cat([x_out_fragments, h_out_fragments], dim=2)
        xh_out = xh_out_linker * linker_mask + xh_out_fragments * fragment_mask

        # Overwrite frame 0 with the resulting x and h
        chain[0] = xh_out

        return chain

    def sample_p_zs_given_zt(self, s, t, z_t, node_mask, edge_mask, context):
        """Samples from zs ~ p(zs | zt). Only used during sampling"""
        gamma_s = self.gamma(s)
        gamma_t = self.gamma(t)

        sigma2_t_given_s, sigma_t_given_s, alpha_t_given_s = self.sigma_and_alpha_t_given_s(gamma_t, gamma_s, z_t)
        sigma_s = self.sigma(gamma_s, target_tensor=z_t)
        sigma_t = self.sigma(gamma_t, target_tensor=z_t)

        # Neural net prediction.
        eps_hat = self.dynamics.forward(
            xh=z_t,
            t=t,
            node_mask=node_mask,
            linker_mask=None,
            edge_mask=edge_mask,
            context=context
        )

        # Check the coordinate part of the prediction against node_mask
        # (presumably verifies it has zero mean over the masked nodes)
        utils.assert_mean_zero_with_mask(eps_hat[:, :, :self.n_dims], node_mask)

        # Compute mu for p(z_s | z_t)
        mu = z_t / alpha_t_given_s - (sigma2_t_given_s / alpha_t_given_s / sigma_t) * eps_hat

        # Compute sigma for p(z_s | z_t)
        sigma = sigma_t_given_s * sigma_s / sigma_t

        # Sample z_s given the parameters derived from z_t
        z_s = self.sample_normal(mu, sigma, node_mask)
        return z_s

    def sample_q_zs_given_zt_and_x(self, s, t, z_t, x, node_mask):
        """Samples from zs ~ q(zs | zt, x). Only used during sampling.

        Noise is drawn only for the nodes selected by ``node_mask``; the
        caller in ``sample_chain`` passes the fragment mask here.
        """
        gamma_s = self.gamma(s)
        gamma_t = self.gamma(t)

        sigma2_t_given_s, sigma_t_given_s, alpha_t_given_s = self.sigma_and_alpha_t_given_s(gamma_t, gamma_s, z_t)
        sigma_s = self.sigma(gamma_s, target_tensor=z_t)
        sigma_t = self.sigma(gamma_t, target_tensor=z_t)
        alpha_s = self.alpha(gamma_s, x)

        # Posterior mean: weighted combination of z_t and the clean data x
        mu = (
            alpha_t_given_s * (sigma_s ** 2) / (sigma_t ** 2) * z_t +
            alpha_s * sigma2_t_given_s / (sigma_t ** 2) * x
        )

        # Compute sigma for q(zs | zt, x)
        sigma = sigma_t_given_s * sigma_s / sigma_t

        # Sample zs given the parameters derived from zt
        z_s = self.sample_normal(mu, sigma, node_mask)
        return z_s

    def sample_p_xh_given_z0(self, z_0, node_mask, edge_mask, context):
        """Samples x, h ~ p(x, h | z0) for all nodes in ``node_mask`` using the network."""
        zeros = torch.zeros(size=(z_0.size(0), 1), device=z_0.device)
        gamma_0 = self.gamma(zeros)

        # Computes sqrt(sigma_0^2 / alpha_0^2)
        sigma_x = self.SNR(-0.5 * gamma_0).unsqueeze(1)

        eps_hat = self.dynamics.forward(
            xh=z_0,
            t=zeros,
            node_mask=node_mask,
            linker_mask=None,
            edge_mask=edge_mask,
            context=context
        )
        utils.assert_mean_zero_with_mask(eps_hat[:, :, :self.n_dims], node_mask)

        mu_x = self.compute_x_pred(eps_hat, z_0, gamma_0)
        xh = self.sample_normal(mu=mu_x, sigma=sigma_x, node_mask=node_mask)

        x, h = xh[:, :, :self.n_dims], xh[:, :, self.n_dims:]
        x, h = self.unnormalize(x, h)
        # Snap features to a one-hot vector at the arg-max channel
        h = F.one_hot(torch.argmax(h, dim=2), self.in_node_nf) * node_mask

        return x, h

    def sample_q_xh_given_z0_and_x(self, z_0, node_mask):
        """Samples x, h from z0 without the network, using freshly drawn noise.

        Computes ``z_0 / alpha_0 - (sigma_0 / alpha_0) * eps`` with ``eps``
        sampled for the nodes in ``node_mask``. Despite the name, no clean
        ``x`` is passed in.
        """
        zeros = torch.zeros(size=(z_0.size(0), 1), device=z_0.device)
        gamma_0 = self.gamma(zeros)
        alpha_0 = self.alpha(gamma_0, z_0)
        sigma_0 = self.sigma(gamma_0, z_0)

        eps = self.sample_combined_position_feature_noise(z_0.size(0), z_0.size(1), node_mask)

        xh = (1 / alpha_0) * z_0 - (sigma_0 / alpha_0) * eps

        x, h = xh[:, :, :self.n_dims], xh[:, :, self.n_dims:]
        x, h = self.unnormalize(x, h)
        # Snap features to a one-hot vector at the arg-max channel
        h = F.one_hot(torch.argmax(h, dim=2), self.in_node_nf) * node_mask

        return x, h

    def sample_combined_position_feature_noise(self, n_samples, n_nodes, mask):
        """Draw masked noise; positions use the center-of-gravity-zero sampler."""
        z_x = utils.sample_center_gravity_zero_gaussian_with_mask(
            size=(n_samples, n_nodes, self.n_dims),
            device=mask.device,
            node_mask=mask
        )
        z_h = utils.sample_gaussian_with_mask(
            size=(n_samples, n_nodes, self.in_node_nf),
            device=mask.device,
            node_mask=mask
        )
        z = torch.cat([z_x, z_h], dim=2)
        return z

    def dimensionality(self, mask):
        """Coordinate degrees of freedom per sample: ``(masked nodes - 1) * n_dims``.

        One node's worth of dimensions is subtracted relative to
        ``EDM.dimensionality`` -- presumably to account for the
        zero-center-of-gravity constraint on the position noise.
        """
        return (self.numbers_of_nodes(mask) - 1) * self.n_dims