# song-recommender / custom_models.py
# Scraped from a hosted file page: author jrno, commit f6b6982 ("Add dockerfile").
# Page metadata: 779 bytes; virus scan reported "No virus".
from fastai.tabular.all import *
def create_params(size, std=0.01):
    """Create a trainable ``nn.Parameter`` with zero-mean Gaussian initialisation.

    Args:
        size: Shape of the parameter — an int for a 1-D tensor, or any
            iterable of ints (e.g. ``[n_users, n_factors]``).
        std: Standard deviation of the normal initialisation. Defaults to
            0.01, the value previously hard-coded here.

    Returns:
        An ``nn.Parameter`` (``requires_grad=True``) of the requested shape
        whose entries are drawn from N(0, std**2) using torch's global RNG.
    """
    shape = (size,) if isinstance(size, int) else tuple(size)
    # torch.empty skips the zero-fill that normal_ immediately overwrites;
    # RNG consumption, and therefore the sampled values, are unchanged.
    return nn.Parameter(torch.empty(shape).normal_(0, std))
class DotProductBias(Module):
    """Score (user, item) pairs by a dot product of learned factors plus biases.

    Args:
        n_users: Number of rows in the user factor and user bias tables.
        n_items: Number of rows in the item factor and item bias tables.
        n_factors: Length of each user/item latent factor vector.
        y_range: ``(low, high)`` pair unpacked into fastai's
            ``sigmoid_range`` to squash the raw score; defaults to ``(0, 1.5)``.
    """

    def __init__(self, n_users, n_items, n_factors, y_range=(0, 1.5)):
        super().__init__()
        # NOTE: keep this creation order — every create_params call draws from
        # torch's global RNG, so reordering would change the initial weights
        # as well as the order of self.parameters().
        self.user_factors = create_params([n_users, n_factors])
        self.user_bias = create_params([n_users])
        self.item_factors = create_params([n_items, n_factors])
        self.item_bias = create_params([n_items])
        self.y_range = y_range

    def forward(self, x):
        """Return one score per row of ``x``.

        Args:
            x: 2-D index tensor; column 0 indexes users and column 1 indexes
                items (assumed integer dtype, shape ``(batch, >=2)`` — TODO
                confirm against the DataLoaders).

        Returns:
            1-D tensor of length ``batch``: user·item dot product plus user
            and item biases, passed through ``sigmoid_range(..., *self.y_range)``.
        """
        user_idx, item_idx = x[:, 0], x[:, 1]
        affinity = (self.user_factors[user_idx] * self.item_factors[item_idx]).sum(dim=1)
        offset = self.user_bias[user_idx] + self.item_bias[item_idx]
        return sigmoid_range(affinity + offset, *self.y_range)