jrno's picture
Try out new version
aad08d3
raw
history blame contribute delete
No virus
1.21 kB
from fastai.collab import load_learner
from fastai.tabular.all import *
class DotProductBias(Module):
    """Latent-factor recommender: user·movie dot product plus per-user and per-movie biases.

    The raw score is squashed into ``y_range`` with ``sigmoid_range``.

    NOTE(review): presumably this class lives at module level so that
    ``load_learner`` can unpickle a learner whose model is a
    ``DotProductBias``. If so, the attribute names below must stay exactly
    as they are to match the saved weights. Confirm before renaming.
    """

    def __init__(self, n_users, n_movies, n_factors, y_range=(0, 1.1)):
        # Construction order is kept deliberately: each Embedding draws its
        # initial weights at creation time, and registration order also fixes
        # the parameter / state_dict ordering.
        self.user_factors = Embedding(n_users, n_factors)   # one n_factors row per user
        self.user_bias = Embedding(n_users, 1)              # one scalar per user
        self.movie_factors = Embedding(n_movies, n_factors) # one n_factors row per movie
        self.movie_bias = Embedding(n_movies, 1)            # one scalar per movie
        self.y_range = y_range

    def forward(self, x):
        # Assumes x is an integer index tensor shaped (batch, 2):
        # column 0 = user id, column 1 = movie id. TODO confirm against the DataLoaders.
        user_ids, movie_ids = x[:, 0], x[:, 1]
        user_vecs = self.user_factors(user_ids)
        movie_vecs = self.movie_factors(movie_ids)
        # Row-wise dot product, kept 2-D as (batch, 1) so the (batch, 1)
        # bias lookups add element-wise.
        interaction = (user_vecs * movie_vecs).sum(dim=1, keepdim=True)
        interaction += self.user_bias(user_ids) + self.movie_bias(movie_ids)
        return sigmoid_range(interaction, *self.y_range)
def custom_accuracy(prediction, target, *, threshold=0.95):
    """Return the fraction of rows whose thresholded prediction exactly equals the target.

    Predictions strictly greater than ``threshold`` are snapped to ``1.0``.
    Every other prediction is compared to the target unchanged, so it only
    counts as correct on an exact float match.

    Args:
        prediction: model output tensor, either ``(batch, 1)`` (as returned
            by ``DotProductBias.forward``, which uses ``keepdim=True``) or
            ``(batch,)``.
        target: ground-truth tensor with the same number of elements,
            e.g. ``(batch, 1)``.
        threshold: keyword-only cut-off above which a prediction is treated
            as ``1.0``. Defaults to 0.95, the previously hard-coded value.

    Returns:
        0-dim float tensor: correct rows / total rows (NaN for an empty batch,
        as before).
    """
    # Flatten BOTH sides to (batch,). Squeezing only the target left a
    # (batch, 1) prediction broadcasting against a (batch,) target, which
    # compared every prediction with every target, a (batch, batch) matrix,
    # and could report accuracies above 1.
    prediction = prediction.reshape(-1)
    target = target.reshape(-1)
    # masked_fill keeps prediction's device and dtype, unlike a CPU-only
    # torch.tensor(1.0) constant.
    prediction = prediction.masked_fill(prediction > threshold, 1.0)
    # NOTE(review): predictions <= threshold are compared by exact float
    # equality, so they are almost never "correct" unless they hit the
    # target value exactly. Confirm this is the intended metric.
    correct = (prediction == target).float()
    return correct.mean()
async def setup_learner(model_filename: str):
    """Load an exported fastai learner from ``model_filename`` and point its DataLoaders at the CPU.

    Args:
        model_filename: path handed unchanged to ``load_learner``.

    Returns:
        The learner object returned by ``load_learner``.

    Note: although declared ``async``, the body contains no ``await``. The
    load therefore runs synchronously on the event-loop thread when this
    coroutine is awaited.
    NOTE(review): per the fastai docs, ``load_learner`` deserializes with
    pickle, so only pass trusted files.
    """
    learner = load_learner(model_filename)
    loaders = learner.dls
    loaders.device = 'cpu'
    return learner