jrno's picture
refresh model
e2a6226
raw
history blame
No virus
587 Bytes
from fastai.collab import load_learner
from fastai.tabular.all import *
def custom_accuracy(prediction, target):
    """Return the fraction of predictions that exactly equal their target.

    Any prediction strictly greater than 0.95 is first replaced by 1.0;
    all other predictions are left unchanged and therefore only count as
    correct when they are exactly equal to the target value.

    Args:
        prediction: Tensor of predicted values, shaped ``[N]`` or ``[N, 1]``.
        target: Tensor of expected values, shaped ``[N]`` or ``[N, 1]``.

    Returns:
        A 0-dim float tensor: number of exact matches divided by
        ``len(target)``.
    """
    # ones_like keeps the replacement value on the same device and dtype
    # as the predictions instead of allocating a fresh CPU scalar per call.
    prediction = torch.where(
        prediction > 0.95, torch.ones_like(prediction), prediction
    )
    # Collapse a singleton column ([N, 1] -> [N]) on either tensor. Guarding
    # on dim() avoids an IndexError for inputs that are already 1-D, and
    # aligning both sides prevents an accidental [N, N] broadcast below.
    if target.dim() > 1:
        target = target.squeeze(1)
    if prediction.dim() > 1:
        prediction = prediction.squeeze(1)
    correct = (prediction == target).float()
    return correct.sum() / len(target)
async def setup_learner(model_filename: str):
    """Load an exported learner from ``model_filename`` and pin it to the CPU.

    The function is declared ``async`` but performs no ``await`` itself; the
    load runs synchronously once the coroutine is awaited.

    NOTE(review): ``load_learner`` is understood to unpickle the file, so
    only trusted model files should be passed here - confirm against the
    fastai documentation.

    Args:
        model_filename: Path of the exported model file to load.

    Returns:
        The loaded learner, with ``dls.device`` set to ``'cpu'``.
    """
    learner = load_learner(model_filename)
    # Point the learner's DataLoaders at the CPU before handing it back.
    learner.dls.device = 'cpu'
    return learner