|
""" |
|
An example of how to use this:

x, y, y_target = priors.fast_gp.get_batch(1, 100, num_features, hyperparameters=(1e-4, 1., .6), equidistant_x=True)

fig, ax = pyplot.subplots(figsize=[10, 10])

plot_model_and_orig_curve(ax, SOME_MODEL, x, y, given_indices=[10, 40, 60])

Don't worry, it is normal for this to be slow: one rectangle is drawn per query point and y-bucket.
|
""" |
|
import matplotlib.patches as patches |
|
import seaborn as sns |
|
import torch |
|
|
|
|
|
def add_rect(ax, coord, height, width, color):
    """
    Add a borderless, filled rectangle to ``ax``.

    NOTE: despite their names, ``height`` is used as the rectangle's horizontal extent
    (matplotlib's ``width``) and ``width`` as its vertical extent (matplotlib's ``height``),
    because they fill the positional ``width``/``height`` slots of ``Rectangle(xy, width, height)``.
    The parameter names are kept unchanged for compatibility with keyword callers.

    :param ax: A standard pyplot ax.
    :param coord: ``(x, y)`` anchor point of the rectangle.
    :param height: Extent along the x-axis (see note above).
    :param width: Extent along the y-axis (see note above).
    :param color: Face color of the rectangle.
    """
    ax.add_patch(
        patches.Rectangle(
            xy=coord,
            width=height,   # second positional slot of Rectangle: x-extent
            height=width,   # third positional slot of Rectangle: y-extent
            linewidth=1,
            edgecolor='none',
            facecolor=color,
        )
    )
|
|
|
|
|
def heatmap_with_box_sizes(ax, data: torch.Tensor, x_starts, x_ends, y_starts, y_ends,
                           palette=None, set_lims=True):
    """
    Draw a heatmap whose cells may each have their own width and height.

    Beware all x and y arrays should be sorted from small to large and the data will appear in
    that same order: Small indexes map to lower x/y-axis values.

    :param ax: A standard pyplot ax to draw on.
    :param data: 2-D tensor indexed as ``data[row_i, col_i]``. Row ``row_i`` spans
        ``[y_starts[row_i], y_ends[row_i]]`` and column ``col_i`` spans
        ``[x_starts[col_i], x_ends[col_i]]``.
    :param x_starts: Left edge of each column.
    :param x_ends: Right edge of each column.
    :param y_starts: Lower edge of each row.
    :param y_ends: Upper edge of each row.
    :param palette: Callable mapping a value in [0, 1] to a color (e.g. a matplotlib colormap).
        Defaults to seaborn's "rocket" colormap, created at call time.
    :param set_lims: If True, set the axis limits to the outer extent of the cells.
    """
    if palette is None:
        palette = sns.color_palette("rocket", as_cmap=True)

    if set_lims:
        ax.set_xlim(x_starts[0], x_ends[-1])
        ax.set_ylim(y_starts[0], y_ends[-1])

    # Min-max normalise into [0, 1] for the palette. For constant data the range is zero and the
    # division would yield NaN for every cell, so map such data to 0 instead.
    data_min = data.min()
    value_range = data.max() - data_min
    if value_range == 0:
        data = torch.zeros_like(data, dtype=torch.float)
    else:
        data = (data - data_min) / value_range

    for col_i, (col_start, col_end) in enumerate(zip(x_starts, x_ends)):
        for row_i, (row_start, row_end) in enumerate(zip(y_starts, y_ends)):
            # add_rect forwards its 3rd/4th arguments to Rectangle as (x-extent, y-extent).
            add_rect(ax, (col_start, row_start), col_end - col_start, row_end - row_start,
                     palette(data[row_i, col_i].item()))
|
|
|
|
|
def plot_bar_distribution(ax, x: torch.Tensor, bar_borders: torch.Tensor, predictions: torch.Tensor, **kwargs):
    """
    Plot per-query bucketed predictions as a heatmap over (x, y).

    Each query point becomes one column and each y-bucket one row. Every prediction is divided by
    its bucket width before plotting, so per-bucket masses are shown as densities.

    :param ax: A standard pyplot ax.
    :param x: Query locations; after ``squeeze()`` it must be 1-D with length n.
    :param bar_borders: 1-D tensor of k + 1 bucket borders along the y-axis.
    :param predictions: Per-query bucket values; after ``squeeze()`` it must have shape (n, k).
    :param kwargs: Forwarded to ``heatmap_with_box_sizes`` (e.g. ``palette``, ``set_lims``).

    NOTE: with a single query point, ``squeeze()`` also drops the length-1 query axis, so the
    shape assertion fails.
    """
    x = x.squeeze()
    predictions = predictions.squeeze()
    assert (len(x.shape) == 1 and len(predictions.shape) == 2 and len(predictions) == len(x)
            and len(bar_borders.shape) == 1 and len(bar_borders) - 1 == predictions.shape[1]), \
        (f"Incompatible shapes: x {tuple(x.shape)}, predictions {tuple(predictions.shape)}, "
         f"bar_borders {tuple(bar_borders.shape)}")

    # Bucket i becomes the heatmap row spanning [bar_borders[i], bar_borders[i + 1]].
    y_starts = bar_borders[:-1]
    y_ends = bar_borders[1:]

    # Sort the queries along x and reorder their predictions to match.
    x, order = x.sort(0)
    bucket_widths = bar_borders[1:] - bar_borders[:-1]
    densities = predictions[order] / bucket_widths

    # Each column reaches halfway to its neighbours. The outermost columns stop exactly at the
    # first/last query point.
    midpoints = (x[1:] + x[:-1]) / 2
    x_starts = torch.cat([x[:1], midpoints])
    x_ends = torch.cat([midpoints, x[-1:]])

    # heatmap_with_box_sizes expects data[row, col] == data[bucket, query], hence the transpose.
    heatmap_with_box_sizes(ax, densities.T, x_starts, x_ends, y_starts, y_ends, **kwargs)
|
|
|
|
|
def plot_model_w_eval_pos(ax, model, x, y, single_eval_pos, softmax=False, min_max_y=None, **kwargs):
    """
    Run ``model`` on ``(x, y)`` and plot its bucketed predictions for the evaluation positions.

    The model is switched to eval mode and called under ``torch.no_grad()``. Predictions are drawn
    over ``x[single_eval_pos:]``.

    :param ax: A standard pyplot ax.
    :param model: Called as ``model((x, y), single_eval_pos=...)``. It must expose
        ``model.criterion.borders``, a sorted 1-D tensor of bucket borders (it is searched with
        ``torch.searchsorted``).
    :param x: Inputs handed to the model.
    :param y: Targets handed to the model.
    :param single_eval_pos: Passed through to the model; also the first position that is plotted.
    :param softmax: If True, apply a softmax over the last (bucket) dimension of the outputs.
    :param min_max_y: Optional ``(low, high)``. If given, only borders ``b`` with ``low <= b < high``
        are kept, together with the buckets lying between consecutive kept borders.
    :param kwargs: Forwarded to ``plot_bar_distribution``.
    """
    with torch.no_grad():
        model.eval()
        outputs = model((x, y), single_eval_pos=single_eval_pos)
        if softmax:
            outputs = outputs.softmax(-1)

        all_borders = model.criterion.borders
        if not min_max_y:
            shown_borders = all_borders
        else:
            low, high = min_max_y[0], min_max_y[1]
            first_idx = torch.searchsorted(all_borders, low)
            # searchsorted never returns more than len(all_borders), so no clamping is needed.
            stop_idx = torch.searchsorted(all_borders, high)
            shown_borders = all_borders[first_idx:stop_idx]
            # n kept borders delimit n - 1 buckets.
            outputs = outputs[..., first_idx:stop_idx - 1]

        plot_bar_distribution(ax, x[single_eval_pos:], shown_borders, outputs, **kwargs)
|
|
|
|
|
def plot_model_and_orig_curve(ax, model, x, y, given_indices=None):
    """
    Plot a curve, the points given to the model, and the model's predictive distribution.

    :param ax: A standard pyplot ax
    :param model: A Transformer Model with `single_eval_pos`
    :param x: A three-dimensional input tensor of shape (seq_len, 1, 1), i.e. x.shape[1]=1 and
        x.shape[2]=1. Dimension 0 indexes the points along the curve.
    :param y: A two-dimensional tensor of shape (seq_len, 1), i.e. y.shape[1]=1
    :param given_indices: The indexes in y which should be given to the model (the training
        points). Defaults to ``[0]``.
    :return: None; everything is drawn onto ``ax``.
    """
    # A None sentinel avoids a shared mutable default. A tuple is not an option: x[(0,)] would
    # index differently from x[[0]].
    if given_indices is None:
        given_indices = [0]

    # The given points come first and act as the model's context. They are followed by every
    # point of the curve, which are the positions the model predicts for.
    x_winput = torch.cat([x[given_indices], x], 0)
    y_winput = torch.cat([y[given_indices], y], 0)

    ax.plot(x.squeeze(), y.squeeze(), color='grey')
    ax.plot(x.squeeze()[given_indices], y.squeeze()[given_indices], 'o', color='black')
    plot_model_w_eval_pos(ax, model, x_winput, y_winput, len(given_indices),
                          min_max_y=(y.min() - .3, y.max() + .3), softmax=True,
                          palette=sns.cubehelix_palette(start=2, rot=0, dark=0.4, light=1, as_cmap=True))
|
|
|
|
|
|