Model description
This is a LogisticRegressionCV model trained on averages of patch embeddings from the Imagenette dataset. It forms the GAM (generalized additive model) of an Emb-GAM extended to images. Patch embeddings are meant to be extracted with the facebook/dino-vitb16 DINO checkpoint.
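To make the additive structure concrete, here is a minimal NumPy sketch with toy numbers, not the trained model. It assumes mean pooling of the patch embeddings, as the description above states. The weights, bias and embeddings are random stand-ins, and the 196-patch count assumes a 224×224 input cut into 16×16 patches. Because the classifier is linear, each class logit splits into one term per patch plus the intercept.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy sizes (assumed): 10 classes, 196 patches (14x14 grid), 768-dim ViT-B/16 embeddings
n_classes, n_patches, dim = 10, 196, 768
W = rng.normal(size=(n_classes, dim))   # stand-in for logistic_regression.coef_
b = rng.normal(size=n_classes)          # stand-in for logistic_regression.intercept_
E = rng.normal(size=(n_patches, dim))   # stand-in for one image's patch embeddings

# Logits of the linear model applied to the averaged embedding
logits = W @ E.mean(axis=0) + b

# The same logits written as an additive model: one contribution per patch, plus the intercept
per_patch = (W @ E.T) / n_patches       # shape (n_classes, n_patches)
logits_additive = per_patch.sum(axis=1) + b

print(np.allclose(logits, logits_additive))  # True
```

The getting-started example later in this card sums the embeddings instead of averaging them; summing multiplies every per-patch term by the number of patches but leaves the intercept unchanged.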
Intended uses & limitations
This model is not intended to be used in production.
Training Procedure
Hyperparameters
The model is trained with the hyperparameters below.
Hyperparameter | Value |
---|---|
Cs | 10 |
class_weight | None |
cv | StratifiedKFold(n_splits=5, random_state=1, shuffle=True) |
dual | False |
fit_intercept | True |
intercept_scaling | 1.0 |
l1_ratios | None |
max_iter | 100 |
multi_class | auto |
n_jobs | None |
penalty | l2 |
random_state | 1 |
refit | False |
scoring | None |
solver | lbfgs |
tol | 0.0001 |
verbose | 0 |
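The table corresponds roughly to an estimator configured as in the sketch below. The synthetic data from make_classification is only a stand-in for the averaged DINO embeddings and is not the Imagenette training data; its sizes are arbitrary. Parameters the table shows at their scikit-learn defaults are omitted; in particular, multi_class is deprecated in recent scikit-learn releases.

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegressionCV
from sklearn.model_selection import StratifiedKFold

# Synthetic stand-in for averaged patch embeddings (NOT the real Imagenette features)
X, y = make_classification(
    n_samples=500, n_features=32, n_informative=10, n_classes=10, random_state=1
)

clf = LogisticRegressionCV(
    Cs=10,  # 10 values of C on a log scale between 1e-4 and 1e4
    cv=StratifiedKFold(n_splits=5, random_state=1, shuffle=True),
    solver="lbfgs",
    max_iter=100,
    tol=1e-4,
    random_state=1,
    refit=False,  # no final refit: fold coefficients for the best C are averaged
)
clf.fit(X, y)

print(clf.coef_.shape)  # one weight vector per class
print(clf.C_)           # selected regularization strength
```

With refit=False the estimator skips the final fit on the full data, which is cheaper but means coef_ is a fold average rather than a single fitted solution.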
Model Plot
The fitted model's text representation is below.
LogisticRegressionCV(cv=StratifiedKFold(n_splits=5, random_state=1, shuffle=True),random_state=1, refit=False)
Evaluation Results
The evaluation results are below.
Metric | Value |
---|---|
accuracy | 0.97707 |
f1 score | 0.97707 |
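The card does not say which averaging the F1 score uses. The sketch below, using scikit-learn's accuracy_score and f1_score on made-up labels, shows one possible reason the two values coincide: for single-label multiclass predictions, micro-averaged F1 equals accuracy. This is an illustration, not a description of the authors' evaluation.

```python
from sklearn.metrics import accuracy_score, f1_score

# Made-up labels for illustration only
y_true = [0, 1, 2, 2, 1, 0, 2, 1]
y_pred = [0, 2, 2, 2, 1, 0, 1, 1]

acc = accuracy_score(y_true, y_pred)
f1_micro = f1_score(y_true, y_pred, average="micro")
f1_macro = f1_score(y_true, y_pred, average="macro")

print(acc, f1_micro)  # 0.75 0.75
print(f1_macro)       # macro averaging generally differs from accuracy
```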
How to Get Started with the Model
Use the code below to get started with the model.
```python
import os
import pickle

import torch
from PIL import Image
from skops import hub_utils
from transformers import AutoFeatureExtractor, AutoModel

# Load the DINO ViT-B/16 embedding model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
feature_extractor = AutoFeatureExtractor.from_pretrained('facebook/dino-vitb16')
model = AutoModel.from_pretrained('facebook/dino-vitb16').eval().to(device)

# Download the model repository and load the logistic regression
os.mkdir('emb-gam-dino')
hub_utils.download(repo_id='Ramos-Ramos/emb-gam-dino', dst='emb-gam-dino')
with open('emb-gam-dino/model.pkl', 'rb') as file:
    logistic_regression = pickle.load(file)  # only unpickle files you trust

# Load the image (convert to RGB in case the PNG has an alpha channel)
img = Image.open('examples/english_springer.png').convert('RGB')

# Preprocess the image and move the tensors to the model's device
inputs = {k: v.to(device) for k, v in feature_extractor(img, return_tensors='pt').items()}

# Extract patch embeddings, dropping the CLS token at position 0
with torch.no_grad():
    patch_embeddings = model(**inputs).last_hidden_state[0, 1:].cpu()

# Classify from the pooled (summed) patch embeddings
pred = logistic_regression.predict(patch_embeddings.sum(dim=0, keepdim=True).numpy())

# Per-patch contribution to each class logit: shape (n_classes, n_patches)
patch_contributions = logistic_regression.coef_ @ patch_embeddings.T.numpy()
```
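patch_contributions holds one row per class and one column per patch. A minimal sketch of reading it, assuming a 224×224 preprocessed input split into 16×16 patches (a 14×14 grid of 196 patches in row-major order); top_patches is a hypothetical helper, not part of the model's code, and the random matrix stands in for the real output.

```python
import numpy as np

def top_patches(patch_contributions, class_idx, k=5, grid_side=14):
    """Return (row, col, contribution) for the k patches that raise class_idx's logit most.

    Assumes patches are ordered row-major on a grid_side x grid_side grid.
    """
    row = patch_contributions[class_idx]
    order = np.argsort(row)[::-1][:k]  # patch indices, largest contribution first
    return [(int(i) // grid_side, int(i) % grid_side, float(row[i])) for i in order]

# Random stand-in for logistic_regression.coef_ @ patch_embeddings.T.numpy()
rng = np.random.default_rng(0)
contrib = rng.normal(size=(10, 196))

for r, c, v in top_patches(contrib, class_idx=3, k=3):
    print(f"patch at row {r}, col {c}: {v:+.3f}")
```

With the real model, the row for the predicted label would be list(logistic_regression.classes_).index(pred[0]), since the rows of coef_ follow the order of classes_.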
Model Card Authors
This model card was written by the following author:
Patrick Ramos
Model Card Contact
You can contact the model card authors through the following channels: [More Information Needed]
Citation
BibTeX:
@article{singh2022emb,
title={Emb-GAM: an Interpretable and Efficient Predictor using Pre-trained Language Models},
author={Singh, Chandan and Gao, Jianfeng},
journal={arXiv preprint arXiv:2209.11799},
year={2022}
}
Additional Content
[Figure: confusion matrix]
Demo
Check out our Hugging Face Space, which performs Imagenette classification and visualizes patch contributions per label.
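A per-label visualization of this kind can be approximated by reshaping one class's row of patch contributions into the patch grid and upsampling it to pixel resolution. This is a sketch under assumptions (a 14×14 grid of 16×16-pixel patches), not the Space's actual code; contribution_heatmap is a hypothetical helper and the random input stands in for one row of patch_contributions.

```python
import numpy as np

def contribution_heatmap(row, grid_side=14, patch_size=16):
    """Turn one class's per-patch contributions into an (H, W) heatmap scaled to [0, 1]."""
    grid = np.asarray(row, dtype=float).reshape(grid_side, grid_side)
    # Repeat each patch value over its patch_size x patch_size block of pixels
    heat = np.kron(grid, np.ones((patch_size, patch_size)))
    # Min-max normalize so the map can be used as an overlay (e.g. alpha or colormap)
    span = heat.max() - heat.min()
    return (heat - heat.min()) / span if span > 0 else np.zeros_like(heat)

rng = np.random.default_rng(0)
heat = contribution_heatmap(rng.normal(size=196))
print(heat.shape)  # (224, 224)
```

The resulting map aligns with the preprocessed 224×224 input, so overlaying it on the original image requires applying the same resize and crop to that image first.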