# Source snapshot f50f696 (file size: 7,339 bytes)
import torch
from torch import nn
class BarDistribution(nn.Module):
    """Piecewise-constant ("bar") density over a bounded interval.

    The interval ``[borders[0], borders[-1]]`` is split into ``num_bars``
    buckets by ``borders``.  Logits over the buckets are converted into bucket
    probabilities with a softmax; inside a bucket the probability is spread
    uniformly over the bucket's width.
    """

    def __init__(self, borders: torch.Tensor):
        """Register the bucket borders and the derived bucket widths as buffers.

        Args:
            borders: 1-D tensor of bucket edges in ascending order.  It should
                start with the minimum and end with the maximum of the
                supported range; all targets are expected to lie in between.

        Raises:
            AssertionError: if ``borders`` is not 1-D or is not sorted.
        """
        super().__init__()
        assert len(borders.shape) == 1
        self.register_buffer('borders', borders)
        self.register_buffer('bucket_widths', self.borders[1:] - self.borders[:-1])
        full_width = self.bucket_widths.sum()
        assert (full_width - (self.borders[-1] - self.borders[0])).abs() < 1e-4, f'diff: {full_width - (self.borders[-1] - self.borders[0])}'
        # Non-negative widths are equivalent to sorted borders.  Unlike comparing
        # an argsort against arange, this does not depend on how ties are ordered.
        assert (self.bucket_widths >= 0).all(), "Please provide sorted borders!"
        self.num_bars = len(borders) - 1

    def map_to_bucket_idx(self, y):
        """Return, for every value in ``y``, the index of the bucket containing it.

        Values equal to the first/last border are mapped to the first/last
        bucket.  Values outside the borders yield indices outside
        ``[0, num_bars - 1]``; callers are responsible for handling those.
        """
        target_sample = torch.searchsorted(self.borders, y) - 1
        # searchsorted puts a value equal to the lowest border before it (index -1).
        target_sample[y == self.borders[0]] = 0
        target_sample[y == self.borders[-1]] = self.num_bars - 1
        return target_sample

    def forward(self, logits, y):
        """Return the negative log density (the loss) of ``y`` under ``logits``.

        Args:
            logits: bucket logits, ``T x B x num_bars``.
            y: targets, ``T x B``.

        Raises:
            AssertionError: if a target lies outside the borders or the last
                dimension of ``logits`` is not ``num_bars``.
        """
        target_sample = self.map_to_bucket_idx(y)
        assert (target_sample >= 0).all() and (target_sample < self.num_bars).all(), f'y {y} not in support set for borders (min_y, max_y) {self.borders}'
        assert logits.shape[-1] == self.num_bars, f'{logits.shape[-1]} vs {self.num_bars}'
        bucket_log_probs = torch.log_softmax(logits, -1)
        # Dividing the bucket mass by the bucket width turns it into a density.
        scaled_bucket_log_probs = bucket_log_probs - torch.log(self.bucket_widths)
        return -scaled_bucket_log_probs.gather(-1, target_sample.unsqueeze(-1)).squeeze(-1)

    def mean(self, logits):
        """Return the mean of the distribution: bucket centres weighted by bucket probability."""
        bucket_means = self.borders[:-1] + self.bucket_widths / 2
        p = torch.softmax(logits, -1)
        return p @ bucket_means

    def quantile(self, logits, center_prob=.682):
        """Return the bounds of the central interval holding ``center_prob`` mass.

        Args:
            logits: bucket logits with ``num_bars`` entries in the last dimension.
            center_prob: probability mass of the central interval; a mass of
                ``(1 - center_prob) / 2`` is left on either side.

        Returns:
            A tensor of shape ``(*logits.shape[:-1], 2)`` whose last dimension
            holds the lower and the upper bound.
        """
        logits_shape = logits.shape
        # reshape (rather than view) also accepts non-contiguous logits.
        logits = logits.reshape(-1, logits.shape[-1])
        side_prob = (1 - center_prob) / 2
        probs = logits.softmax(-1)
        flipped_probs = probs.flip(-1)
        cumprobs = torch.cumsum(probs, -1)
        flipped_cumprobs = torch.cumsum(flipped_probs, -1)
        # The upper bound is found as the lower quantile of the mirrored distribution.
        flipped_borders = self.borders.flip(0)

        def find_lower_quantile(probs, cumprobs, side_prob, borders):
            # Interpolate linearly inside the bucket in which the cumulative
            # probability first reaches side_prob.
            idx = int(torch.searchsorted(cumprobs, side_prob).clamp(0, len(cumprobs) - 1))  # this might not do the right for outliers
            # No mass lies to the left of the first bucket.  Indexing
            # cumprobs[-1] here would wrap around to the total mass instead.
            left_prob = cumprobs[idx - 1] if idx > 0 else 0.
            rest_prob = side_prob - left_prob
            left_border, right_border = borders[idx], borders[idx + 1]
            return left_border + (right_border - left_border) * rest_prob / probs[idx]

        results = []
        for p, cp, f_p, f_cp in zip(probs, cumprobs, flipped_probs, flipped_cumprobs):
            lower = find_lower_quantile(p, cp, side_prob, self.borders)
            upper = find_lower_quantile(f_p, f_cp, side_prob, flipped_borders)
            results.append((lower, upper))
        return torch.tensor(results).reshape(*logits_shape[:-1], 2)

    def mode(self, logits):
        """Return the centre of the bucket with the largest logit."""
        mode_inds = logits.argmax(-1)
        bucket_means = self.borders[:-1] + self.bucket_widths / 2
        return bucket_means[mode_inds]

    def ei(self, logits, best_f, maximize=True):
        """Return an expected-improvement score relative to ``best_f``.

        Args:
            logits: bucket logits, ``evaluation_points x batch x feature_dim``.
            best_f: incumbent value to improve upon.
            maximize: if true, values above ``best_f`` count as improvement,
                otherwise values below it.

        Returns:
            Per-bucket improvement weights averaged under the bucket probabilities.
        """
        bucket_min, bucket_max = self.borders[:-1], self.borders[1:]
        best = torch.as_tensor(best_f, dtype=bucket_min.dtype, device=bucket_min.device)
        # NOTE(review): for a bucket that straddles best_f the weight is the mean
        # improvement over the part of the bucket beyond best_f; it is not scaled
        # by the fraction of the bucket lying beyond best_f.  Confirm this is intended.
        if maximize:
            bucket_contributions = ((bucket_max + torch.maximum(bucket_min, best)) / 2 - best).clamp(min=0)
        else:
            # min on max instead of max on min, and compare min < instead of max >
            bucket_contributions = -((torch.minimum(bucket_max, best) + bucket_min) / 2 - best).clamp(max=0)
        bucket_contributions = bucket_contributions.to(dtype=logits.dtype, device=logits.device)
        p = torch.softmax(logits, -1)
        return p @ bucket_contributions
class FullSupportBarDistribution(BarDistribution):
    """Bar distribution whose first and last buckets are half-normal tails.

    The inner buckets keep their uniform density.  In the outermost buckets the
    density follows a half-normal that starts at the neighbouring inner border
    and extends outwards, so targets beyond the outer borders still get a
    finite loss.
    """

    @staticmethod
    def halfnormal_with_p_weight_before(range_max, p=.5):
        """Return a half-normal holding a fraction ``p`` of its mass in ``[0, range_max]``."""
        unit_quantile = torch.distributions.HalfNormal(torch.tensor(1.)).icdf(torch.tensor(p))
        # Scaling a unit half-normal by s scales its quantiles by s as well.
        return torch.distributions.HalfNormal(range_max / unit_quantile)

    def forward(self, logits, y):
        """Return the negative log density (the loss) of ``y`` under ``logits``.

        Args:
            logits: bucket logits, ``T x B x num_bars``.
            y: targets, ``T x B``; values beyond the outer borders are assigned
                to the first/last bucket.

        Raises:
            AssertionError: if there are fewer than two buckets or the last
                dimension of ``logits`` is not ``num_bars``.
        """
        assert self.num_bars > 1
        bucket_idx = self.map_to_bucket_idx(y)
        bucket_idx.clamp_(0, self.num_bars - 1)
        assert logits.shape[-1] == self.num_bars

        log_widths = torch.log(self.bucket_widths)
        density_log_probs = torch.log_softmax(logits, -1) - log_widths
        log_probs = density_log_probs.gather(-1, bucket_idx.unsqueeze(-1)).squeeze(-1)

        lower_tail = self.halfnormal_with_p_weight_before(self.bucket_widths[0])
        upper_tail = self.halfnormal_with_p_weight_before(self.bucket_widths[-1])
        in_first = bucket_idx == 0
        in_last = bucket_idx == self.num_bars - 1

        # TODO look over it again
        # Adding the log width back cancels the uniform-density scaling above, so
        # the outer buckets end up with bucket log-probability + half-normal log-density.
        # The clamp keeps a target sitting exactly on borders[1] at a positive distance.
        dist_below = (self.borders[1] - y[in_first]).clamp(min=.00000001)
        log_probs[in_first] += lower_tail.log_prob(dist_below) + log_widths[0]
        dist_above = y[in_last] - self.borders[-2]
        log_probs[in_last] += upper_tail.log_prob(dist_above) + log_widths[-1]
        return -log_probs

    def mean(self, logits):
        """Return the mean, using the half-normal means for the two outer buckets."""
        lower_tail = self.halfnormal_with_p_weight_before(self.bucket_widths[0])
        upper_tail = self.halfnormal_with_p_weight_before(self.bucket_widths[-1])
        bucket_means = self.borders[:-1] + self.bucket_widths / 2
        # The tails start at the neighbouring inner border and point outwards.
        bucket_means[0] = self.borders[1] - lower_tail.mean
        bucket_means[-1] = self.borders[-2] + upper_tail.mean
        weights = torch.softmax(logits, -1)
        return weights @ bucket_means
def get_bucket_limits(num_outputs: int, full_range: tuple = None, ys: torch.Tensor = None, verbose: bool = True):
    """Compute the ``num_outputs + 1`` borders of ``num_outputs`` buckets.

    With ``ys`` given, the borders are chosen so that every bucket contains the
    same number of the retained ys: trailing ys (in their original order) are
    dropped until the count is divisible by ``num_outputs``, and each inner
    border is placed midway between the largest y of one group and the smallest
    y of the next.  Without ``ys``, ``full_range`` is split into equally wide
    buckets.

    Args:
        num_outputs: number of buckets.
        full_range: ``(min, max)`` of the supported range.  Optional when
            ``ys`` is given, in which case the range of the ys is used; when
            both are given it must enclose all ys.
        ys: sample of target values used to estimate the borders.
        verbose: print progress information (only used when ``ys`` is given).

    Returns:
        1-D tensor of ``num_outputs + 1`` bucket borders.

    Raises:
        AssertionError: if neither ``ys`` nor ``full_range`` is given, if there
            are fewer ys than buckets, or if ``full_range`` does not enclose
            the ys.
    """
    assert (ys is not None) or (full_range is not None)
    if ys is not None:
        ys = ys.flatten()
        assert num_outputs > 0 and len(ys) >= num_outputs, f'Need at least {num_outputs} ys to estimate {num_outputs} buckets, got {len(ys)}.'
        # Compute the remainder before truncating so that the message below
        # reports the number of ys that were actually dropped.
        num_cut_off = len(ys) % num_outputs
        if num_cut_off:
            ys = ys[:-num_cut_off]
        if verbose:
            print(f'Using {len(ys)} y evals to estimate {num_outputs} buckets. Cut off the last {num_cut_off} ys.')
        ys_per_bucket = len(ys) // num_outputs
        if full_range is None:
            full_range = (ys.min(), ys.max())
        else:
            assert full_range[0] <= ys.min() and full_range[1] >= ys.max()
            full_range = torch.tensor(full_range)
        ys_sorted, _ = ys.sort(0)
        # Midpoints between the last y of each group and the first y of the next.
        bucket_limits = (ys_sorted[ys_per_bucket - 1::ys_per_bucket][:-1] + ys_sorted[ys_per_bucket::ys_per_bucket]) / 2
        if verbose:
            print(full_range)
        bucket_limits = torch.cat([full_range[0].unsqueeze(0), bucket_limits, full_range[1].unsqueeze(0)], 0)
    else:
        class_width = (full_range[1] - full_range[0]) / num_outputs
        # The upper limit is appended explicitly so it matches full_range[1] exactly.
        bucket_limits = torch.cat([full_range[0] + torch.arange(num_outputs).float() * class_width, torch.tensor(full_range[1]).unsqueeze(0)], 0)
    assert len(bucket_limits) - 1 == num_outputs and full_range[0] == bucket_limits[0] and full_range[-1] == bucket_limits[-1]
    return bucket_limits
|