# Page header captured with this file (not code): uploaded by jrno,
# commit b571090 "clean-up", 948 bytes.
from fastai.collab import load_learner
from fastai.tabular.all import *
def create_params(size, std=0.01):
    """Return a trainable parameter tensor initialised from N(0, std**2).

    Args:
        size: Shape of the tensor -- a single int for a 1-D parameter, or a
            sequence/iterable of ints such as ``[n_users, n_factors]``.
        std: Standard deviation of the normal initialisation. Defaults to
            0.01, the value that was previously hard-coded here.

    Returns:
        An ``nn.Parameter`` (so ``requires_grad`` is True) of shape ``size``.
    """
    # Normalise to a tuple so both ``5`` and ``[3, 4]`` are accepted.
    shape = (size,) if isinstance(size, int) else tuple(size)
    # normal_ overwrites every element, so there is no need to zero-fill
    # first; the random draws are the same as with torch.zeros(...).
    return nn.Parameter(torch.empty(shape).normal_(mean=0.0, std=std))
class DotProductBias(Module):
    """Collaborative-filtering model: latent-factor dot product plus biases.

    Every user and every item owns a learned vector of ``n_factors`` values
    and a learned scalar bias. A prediction is the dot product of the user
    and item vectors plus both biases, squashed into ``y_range`` with
    fastai's ``sigmoid_range``.

    NOTE(review): instances are presumably restored by unpickling inside
    ``load_learner``, so the class name and attribute names here likely have
    to match the exported model -- do not rename them.
    """

    def __init__(self, n_users, n_items, n_factors, y_range=(0, 1.5)):
        super().__init__()
        # Parameters are created (and randomly initialised) in this exact
        # order; keep it so seeded runs and parameter ordering stay the same.
        self.user_factors = create_params([n_users, n_factors])
        self.user_bias = create_params([n_users])
        self.item_factors = create_params([n_items, n_factors])
        self.item_bias = create_params([n_items])
        # (low, high) bounds passed to sigmoid_range in forward().
        self.y_range = y_range

    def forward(self, x):
        """Score each (user, item) row of ``x``.

        Args:
            x: Integer index tensor; column 0 selects the user and column 1
                the item (presumably shape ``(batch, 2)`` -- confirm against
                the DataLoaders).

        Returns:
            One prediction per row of ``x``, mapped into ``y_range``.
        """
        user_idx = x[:, 0]
        item_idx = x[:, 1]
        # Row-wise dot product of the looked-up factor vectors.
        interaction = (
            self.user_factors[user_idx] * self.item_factors[item_idx]
        ).sum(dim=1)
        bias = self.user_bias[user_idx] + self.item_bias[item_idx]
        return sigmoid_range(interaction + bias, *self.y_range)
async def setup_learner(model_filename: str):
    """Load an exported fastai ``Learner`` and configure it for CPU use.

    Loading is a synchronous file read plus unpickling. It is run in the
    event loop's default executor rather than directly in this coroutine.
    Otherwise every other task on the loop would stall until the model
    finished loading.

    Args:
        model_filename: Path to the saved learner (presumably a file written
            by ``Learner.export``).

    Returns:
        The loaded ``Learner``, with its DataLoaders moved to ``'cpu'``.

    Security:
        ``load_learner`` deserialises with pickle, which can execute
        arbitrary code -- only pass model files from a trusted source.
    """
    import asyncio

    loop = asyncio.get_running_loop()
    learn = await loop.run_in_executor(None, load_learner, model_filename)
    # Keep data produced through learn.dls on the CPU.
    learn.dls.device = 'cpu'
    return learn