---
library_name: transformers
tags:
- vision
- cell-biology
- dino
pipeline_tag: image-feature-extraction
model-index:
- name: cellrepDINO
  results: []
---

# CellrepDINO Model

This is a custom DINO model for extracting rich representations of cell microscopy images in condensed vector/array form. The forward method of the cellrepDINO model returns embeddings that can be used for relevant downstream tasks like perturbation prediction, mechanism of action (MoA) classification, nuclei size and shape estimation, etc. Simply train a basic linear or logistic model using the embeddings.

## Model Details

- Architecture: DINOv2
- Default Model Size: Giant (1.1 B parameters)
- Patch Size: 14
- Default image size: 1024
- Default resize size: 518
- Default center crop: 518

## Setup

Please create and activate a new environment, then run `pip install torch transformers Pillow numpy pandas torchvision omegaconf`.

## Example Usage

There are different types of embeddings one can extract; we recommend the mean/median embeddings over the patch tokens or the class token embedding.

```python
from transformers import AutoModel, AutoProcessor
from PIL import Image
import torch

# Set up device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load model and processor
model = AutoModel.from_pretrained("LPhilllips/cellrepDINO", trust_remote_code=True, weights_only=True)
processor = AutoProcessor.from_pretrained("LPhilllips/cellrepDINO", trust_remote_code=True)

# Move model to device
model = model.to(device)
model.eval()

# For multiple images:
image_paths = ["image1.png", "image2.png"]
images = [Image.open(path).convert('RGB') for path in image_paths]

# Process batch of images
# If you want different resize and center-crop sizes, please specify the
# resize_size and centercrop_size parameters below
batch_inputs = processor.preprocess(images=images, resize_size=518, centercrop_size=518, return_tensors="pt")

# Move image tensors to device
batch_inputs = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch_inputs.items()}

# Generate embeddings for batch
with torch.no_grad():
    batch_outputs = model(**batch_inputs)
    mean_embeddings = batch_outputs['mean_pooled']
    median_embeddings = batch_outputs['median_pooled']
    cls_embeddings = batch_outputs['cls_token']
```

Script for generating embeddings en masse (requires a CSV with an `ImagePath` column; the script additionally needs `pip install tqdm pyarrow` for the progress bar and the Feather output):

```python
import warnings
from pathlib import Path

import numpy as np
import pandas as pd
import torch
from PIL import Image
from tqdm import tqdm
from transformers import AutoModel, AutoProcessor


def load_model_and_processor(model_name="LPhilllips/cellrepDINO", checkpoint_path=None):
    """Load the model and processor and move the model to the available device.

    Args:
        model_name: Hub repository id passed to ``from_pretrained`` for both
            the model and the processor.
        checkpoint_path: Optional path to a local checkpoint (``.pth``) whose
            state dict should be loaded into the model after
            ``from_pretrained``. Leave as ``None`` to use the weights that
            ``from_pretrained`` provides.

    Returns:
        A ``(model, processor, device)`` tuple; the model is on ``device`` and
        in eval mode.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = AutoModel.from_pretrained(model_name, trust_remote_code=True)

    if checkpoint_path is not None:
        state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
        # IMPORTANT: pass model_type='nn', as in the original working code.
        model.load_state_dict(state_dict, model_type='nn')

    processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)

    model = model.to(device)
    model.eval()

    return model, processor, device


def process_batch(image_paths, model, processor, device, batch_size=32, resize_size=518, crop_size=518):
    """Embed one batch of images and report which inputs could be loaded.

    Args:
        image_paths: Paths of the images in this batch.
        model: Model returned by ``load_model_and_processor``.
        processor: Processor returned by ``load_model_and_processor``.
        device: Device the input tensors are moved to.
        batch_size: Unused here; accepted so existing callers keep working.
        resize_size: Forwarded to ``processor.preprocess``.
        crop_size: Forwarded to ``processor.preprocess``.

    Returns:
        ``(mean_embeddings, valid_indices)`` where ``mean_embeddings`` is a
        NumPy array with one row per successfully loaded image and
        ``valid_indices`` holds the positions of those images within
        ``image_paths``. Returns ``(None, [])`` if no image could be loaded.
    """
    images = []
    valid_indices = []

    for idx, path in enumerate(image_paths):
        try:
            # Image.open is lazy: converting inside the ``with`` block forces
            # the pixel data to be decoded here (so corrupt files are caught
            # by this ``try``) and lets the file handle be closed right away.
            # RGB conversion matches the quick-start example above.
            with Image.open(path) as img:
                images.append(img.convert("RGB"))
        except Exception as e:
            # Deliberately best-effort: skip unreadable images, keep going.
            warnings.warn(f"Could not load image {path}: {e}")
            continue
        valid_indices.append(idx)

    if not images:
        return None, []

    # NOTE(review): this script passes ``crop_size`` while the quick-start
    # example passes ``centercrop_size`` -- confirm which keyword the
    # processor actually expects.
    batch_inputs = processor.preprocess(
        images=images,
        resize_size=resize_size,
        crop_size=crop_size,
        return_tensors="pt",
    )

    # Move tensors to the device; leave any non-tensor entries untouched.
    batch_inputs = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch_inputs.items()}

    with torch.no_grad():
        embeddings = model.model.forward_features(batch_inputs['pixel_values'])
        # Average over dim 1 (presumably the patch-token axis -- TODO confirm).
        mean_embeddings = embeddings["x_norm_patchtokens"].mean(dim=1)

    mean_embeddings = mean_embeddings.cpu().numpy()

    return mean_embeddings, valid_indices


def process_and_save_embeddings(csv_path, output_path, batch_size=32):
    """Embed every image listed in a CSV and save the results as a Feather file.

    Args:
        csv_path: CSV file with an ``ImagePath`` column.
        output_path: Path of the Feather file to write. Missing parent
            directories are created.
        batch_size: Number of images embedded per forward pass.

    Raises:
        KeyError: If the CSV has no ``ImagePath`` column.
    """
    # Read and validate the CSV first so a bad input fails before the
    # (large) model is loaded.
    df = pd.read_csv(csv_path)
    if 'ImagePath' not in df.columns:
        raise KeyError(f"{csv_path} must contain an 'ImagePath' column")

    model, processor, device = load_model_and_processor()

    all_embeddings = []
    valid_rows = []

    for start in tqdm(range(0, len(df), batch_size)):
        batch_df = df.iloc[start:start + batch_size]

        embeddings, valid_indices = process_batch(
            batch_df['ImagePath'].tolist(),
            model,
            processor,
            device,
            batch_size=batch_size,
        )

        if embeddings is not None:
            # Keep only the metadata rows whose image was actually embedded.
            all_embeddings.append(embeddings)
            valid_rows.append(batch_df.iloc[valid_indices])

    if not valid_rows:
        print("No valid images were processed")
        return

    final_df = pd.concat(valid_rows, ignore_index=True)
    final_embeddings = np.concatenate(all_embeddings, axis=0)

    # One column per embedding dimension.
    embedding_cols = [f'embedding_{i}' for i in range(final_embeddings.shape[1])]
    embedding_df = pd.DataFrame(final_embeddings, columns=embedding_cols)

    # Both frames carry a fresh 0..n-1 index, so they line up row by row.
    final_df = pd.concat([final_df, embedding_df], axis=1)

    # Make sure the target directory exists before writing the result.
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    final_df.to_feather(output_path)

    print(f"Successfully processed {len(final_df)} images")
    print(f"Results saved to {output_path}")


# Example usage
if __name__ == "__main__":
    csv_path = "csv/with/image/path/columns"  # CSV file with an `ImagePath` column
    output_path = "/your/output/folder/embeddings.feather"  # file path, not a folder

    process_and_save_embeddings(
        csv_path=csv_path,
        output_path=output_path,
        batch_size=32,  # Adjust based on your GPU memory
    )
```