---
library_name: transformers
tags:
- vision
- cell-biology
- dino
pipeline_tag: image-feature-extraction
model-index:
- name: cellrepDINO
  results: []
---

# CellrepDINO Model

This is a custom DINO model for extracting rich representations of cell microscopy images in condensed vector/array form. The forward method of the cellrepDINO model returns embeddings that can be used for downstream tasks such as perturbation prediction, mechanism of action (MoA) classification, and nuclei size and shape estimation. To use them, train a basic linear or logistic model on the embeddings.

## Model Details

- Architecture: DINOv2
- Type: Vision Transformer
- Input Size: 518x518
- Patch Size: 14
- Image Size: 1024
- Center Crop: 518

## Setup

Clone the repository with `git clone --filter=blob:none https://huggingface.co/lhphillips/cellrepDINO`. Then `cd` into the directory and run `pip install -e .`

## Example Usage

Several types of embeddings can be extracted. We recommend the mean or median embedding over the patch tokens, or the class token embedding. The code below computes the mean of the patch token embeddings. To get the median, replace `batch_outputs['x_norm_patchtokens'].mean(dim=1)` with `batch_outputs['x_norm_patchtokens'].median(dim=1).values` (in PyTorch, `median` along a dimension returns both values and indices). To get the class token embeddings, use `batch_embeddings = batch_outputs['x_norm_clstoken']`.

```
from transformers import AutoModel, AutoProcessor
from PIL import Image
import torch

# Set up device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load model and processor
model = AutoModel.from_pretrained("LPhilllips/cellrepDINO", trust_remote_code=True)
processor = AutoProcessor.from_pretrained("LPhilllips/cellrepDINO", trust_remote_code=True)

# Move model to device
model = model.to(device)
model.eval()

# For multiple images:
image_paths = ["image1.png", "image2.png"]
images = [Image.open(path) for path in image_paths]

# Process batch of images
batch_inputs = processor.preprocess(images=images, return_tensors="pt")
batch_inputs = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch_inputs.items()}

# Generate embeddings for batch
with torch.no_grad():
    batch_outputs = model(**batch_inputs)
    batch_embeddings = batch_outputs['x_norm_patchtokens'].mean(dim=1)
```
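
The three pooling options mentioned above can be compared side by side without downloading the model. The sketch below is only an illustration: it replaces the real model output with random tensors stored under the same two keys, and it assumes the class token has shape `(batch, dim)` and the patch tokens have shape `(batch, patches, dim)`. The patch count of 1369 follows from a 518x518 input with patch size 14 (37 x 37 patches); the embedding width of 384 is a placeholder, not a documented property of cellrepDINO.

```python
import torch

# Stand-in for the model output: random tensors under the key names used above.
# 1369 = (518 // 14) ** 2 patch tokens; embed_dim = 384 is an assumed placeholder.
batch_size, num_patches, embed_dim = 2, (518 // 14) ** 2, 384
torch.manual_seed(0)
batch_outputs = {
    "x_norm_clstoken": torch.randn(batch_size, embed_dim),
    "x_norm_patchtokens": torch.randn(batch_size, num_patches, embed_dim),
}

patch_tokens = batch_outputs["x_norm_patchtokens"]

# Option 1: mean over the patch dimension.
mean_embeddings = patch_tokens.mean(dim=1)

# Option 2: median over the patch dimension.
# Tensor.median(dim=...) returns a (values, indices) pair, so keep .values.
median_embeddings = patch_tokens.median(dim=1).values

# Option 3: class token embedding, already one vector per image.
cls_embeddings = batch_outputs["x_norm_clstoken"]

for name, emb in [
    ("mean", mean_embeddings),
    ("median", median_embeddings),
    ("cls", cls_embeddings),
]:
    print(name, tuple(emb.shape))
```

Each option yields one vector per image, so any of them can be passed directly as the feature matrix to a linear or logistic model.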