Edit model card

You need to agree to share your contact information to access this model

This repository is publicly accessible, but you have to accept the conditions to access its files and content.

Log in or Sign Up to review the conditions and access this model content.

CellrepDINO Model

This is a custom DINO model for extracting rich representations of cell microscopy images in condensed vector/array form. The forward method of the cellrepDINO model returns embeddings that can be used for downstream tasks such as perturbation prediction, mechanism of action (MoA) classification, and nuclei size and shape estimation. Simply train a basic linear or logistic regression model on the embeddings.

Model Details

  • Architecture: DINOv2
  • Default Model Size: Giant (1.1 B parameters)
  • Patch Size: 14
  • Default image size: 1024
  • Default resize size: 518
  • Default center crop: 518

Setup

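Create a new environment, activate it, and then run `pip install torch transformers Pillow numpy pandas torchvision omegaconf tqdm pyarrow`. (`tqdm` and `pyarrow` are only needed by the bulk-embedding script: `tqdm` for the progress bar and `pyarrow` for `DataFrame.to_feather`.)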

Example Usage

There are several types of embeddings you can extract; we recommend the mean- or median-pooled patch-token embeddings, or the class-token embedding.

from transformers import AutoModel, AutoProcessor
from PIL import Image
import torch

# Set up device: use CUDA when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load model and processor. trust_remote_code=True lets transformers run the
# custom model/processor code shipped in the Hub repository.
model = AutoModel.from_pretrained("LPhilllips/cellrepDINO", trust_remote_code=True, weights_only=True)
processor = AutoProcessor.from_pretrained("LPhilllips/cellrepDINO", trust_remote_code=True)

# Move model to device and switch to evaluation mode.
model = model.to(device)
model.eval()

# For multiple images:
image_paths = ["image1.png", "image2.png"]

# Open each file in a context manager so its handle is closed right away;
# convert("RGB") returns a fully loaded copy, so closing the file is safe.
images = []
for path in image_paths:
    with Image.open(path) as img:
        images.append(img.convert("RGB"))

# Process batch of images.
# To use different resize and center-crop sizes, change the resize_size and
# centercrop_size parameters below.
batch_inputs = processor.preprocess(images=images, resize_size=518, centercrop_size=518, return_tensors="pt")

# Move image tensors to device (non-tensor entries are passed through unchanged).
batch_inputs = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch_inputs.items()}

# Generate embeddings for batch; no_grad skips autograd bookkeeping for inference.
with torch.no_grad():
    batch_outputs = model(**batch_inputs)
    mean_embeddings = batch_outputs['mean_pooled']
    median_embeddings = batch_outputs['median_pooled']
    cls_embeddings = batch_outputs['cls_token']

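Script for generating embeddings en masse (requires a CSV file with an `ImagePath` column). Note: the `from src.bioimageembeddings.models.dinov2 import DINOv2Model` line refers to the authors' internal codebase and is not used by the script; delete it if it raises an `ImportError`.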

from transformers import AutoModel, AutoProcessor
from PIL import Image
import torch
import pandas as pd
import numpy as np
from tqdm import tqdm
import warnings
from pathlib import Path
from src.bioimageembeddings.models.dinov2 import DINOv2Model


def load_model_and_processor(model_name="LPhilllips/cellrepDINO", checkpoint_path=None):
    """Load the model and processor, and move the model to the best available device.

    Args:
        model_name: Hugging Face Hub repository id to load the model and
            processor from.
        checkpoint_path: Optional path to a local checkpoint whose state dict
            is loaded on top of the Hub weights. When None (the default), the
            weights downloaded from the Hub are used as-is, so no local file is
            required.

    Returns:
        A ``(model, processor, device)`` tuple. The model has been moved to
        ``device`` and put in eval mode.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = AutoModel.from_pretrained(model_name, trust_remote_code=True)

    if checkpoint_path is not None:
        # map_location="cpu" lets GPU-saved checkpoints load on any machine;
        # weights_only=True restricts unpickling to tensors and plain containers.
        state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
        # NOTE(review): model_type='nn' is not an argument of
        # torch.nn.Module.load_state_dict; it is presumably handled by the
        # repository's custom model class. Verify against the remote code.
        model.load_state_dict(state_dict, model_type='nn')

    processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)

    model = model.to(device)
    model.eval()

    return model, processor, device

def process_batch(image_paths, model, processor, device, batch_size=32, resize_size=518, crop_size=518):
    """Compute mean-pooled patch-token embeddings for a list of image files.

    Files that cannot be opened are skipped with a warning (best effort).

    Args:
        image_paths: Sequence of image file paths; all of them are embedded in
            a single forward pass.
        model: Model from ``load_model_and_processor``; its wrapped backbone
            (``model.model``) must provide ``forward_features``.
        processor: Processor from ``load_model_and_processor``.
        device: torch device the model lives on.
        batch_size: Unused; kept so existing callers keep working.
        resize_size: Resize size passed to the processor.
        crop_size: Center-crop size, passed to the processor as
            ``centercrop_size`` (the keyword used in the documented example).

    Returns:
        ``(embeddings, valid_indices)``. ``embeddings`` is a NumPy array with
        one row per successfully loaded image (presumably shape
        ``(n_valid, embed_dim)``), or None if no image could be loaded.
        ``valid_indices`` lists the positions in ``image_paths`` of those images.
    """
    images = []
    valid_indices = []

    for idx, path in enumerate(image_paths):
        try:
            # The context manager releases the file handle. convert("RGB")
            # matches the single-batch example and returns a loaded copy, so
            # closing the file afterwards is safe.
            with Image.open(path) as img:
                images.append(img.convert("RGB"))
        except Exception as e:  # deliberate best effort: skip unreadable files
            warnings.warn(f"Could not load image {path}: {str(e)}")
            continue
        valid_indices.append(idx)

    if not images:
        return None, []

    batch_inputs = processor.preprocess(
        images=images,
        resize_size=resize_size,
        centercrop_size=crop_size,
        return_tensors="pt",
    )

    # Move tensors to the model's device; pass any non-tensor entries through.
    batch_inputs = {k: v.to(device) if isinstance(v, torch.Tensor) else v
                    for k, v in batch_inputs.items()}

    with torch.no_grad():
        embeddings = model.model.forward_features(batch_inputs['pixel_values'])
        # Average over dim 1 of the normalized patch tokens (assumed to be the
        # token axis of a (batch, tokens, dim) tensor) to get one vector per image.
        mean_embeddings = embeddings["x_norm_patchtokens"].mean(dim=1)
        mean_embeddings = mean_embeddings.cpu().numpy()

    return mean_embeddings, valid_indices

def process_and_save_embeddings(csv_path, output_path, batch_size=32):
    """Embed every image listed in a CSV and save metadata plus embeddings to Feather.

    Rows whose image cannot be loaded are dropped. Each remaining row keeps its
    original CSV columns and gains ``embedding_0 .. embedding_{d-1}`` columns.

    Args:
        csv_path: Path to a CSV file with an ``ImagePath`` column.
        output_path: Destination Feather *file*, written with
            ``DataFrame.to_feather`` (requires ``pyarrow``).
        batch_size: Number of CSV rows embedded per forward pass.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
        KeyError: If the CSV has no ``ImagePath`` column.

    Both checks run before the model is loaded, so bad input fails fast.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    # Read and validate the CSV before loading the (large) model, so input
    # errors surface immediately rather than after the model download/load.
    df = pd.read_csv(csv_path)
    if "ImagePath" not in df.columns:
        raise KeyError(f"CSV file {csv_path} has no 'ImagePath' column")

    model, processor, device = load_model_and_processor()

    all_embeddings = []
    valid_rows = []

    for start in tqdm(range(0, len(df), batch_size)):
        batch_df = df.iloc[start:start + batch_size]

        embeddings, valid_indices = process_batch(
            batch_df['ImagePath'].tolist(),
            model, processor, device,
            batch_size=batch_size,
        )
        if embeddings is None:
            continue

        # valid_indices are positions within this batch, hence iloc.
        valid_rows.append(batch_df.iloc[valid_indices])
        all_embeddings.append(embeddings)

    if not valid_rows:
        print("No valid images were processed")
        return

    # ignore_index yields a fresh 0..n-1 index so the column-wise concat below
    # lines up row-for-row with the embedding frame, and to_feather gets a
    # default RangeIndex.
    final_df = pd.concat(valid_rows, ignore_index=True)
    final_embeddings = np.concatenate(all_embeddings, axis=0)

    embedding_cols = [f'embedding_{i}' for i in range(final_embeddings.shape[1])]
    embedding_df = pd.DataFrame(final_embeddings, columns=embedding_cols)

    final_df = pd.concat([final_df, embedding_df], axis=1)
    final_df.to_feather(output_path)

    print(f"Successfully processed {len(final_df)} images")
    print(f"Results saved to {output_path}")

# Example usage: edit the two paths below, then run this file as a script.
if __name__ == "__main__":
    # CSV file containing an 'ImagePath' column (one image file per row).
    csv_path = "path/to/images.csv"
    # Output *file* (not a folder): results are written with DataFrame.to_feather.
    output_path = "path/to/output/embeddings.feather"

    process_and_save_embeddings(
        csv_path=csv_path,
        output_path=output_path,
        batch_size=32,  # Adjust based on your GPU memory
    )
Downloads last month
9,665
Inference API
Inference API (serverless) does not yet support transformers models for this pipeline type.