from PIL import Image
import requests
import torch
from safetensors.torch import load_file, save_file
from transformers import AutoModelForCausalLM, AutoProcessor


def convert_file(
    pt_filename: str,
    sf_filename: str,
    to_removes: dict[str, list[str]],
):
    """Following https://huggingface.co/spaces/safetensors/convert/blob/main/convert.py#L181

    Parameters
    ----------
    pt_filename : str
        PyTorch .bin filename
    sf_filename : str
        safetensors filename
    to_removes : dict[str, list[str]]
        Key is the tensor weight to keep; its value is the list of tensor
        weights tied to that key. Those in the list are removed from the
        state_dict before saving to safetensors format and recorded in the
        safetensors metadata field so they can be properly re-initialized
        when the safetensors file is read back.

    Raises
    ------
    RuntimeError
        If a tensor in the original state_dict does not match its reloaded
        safetensors counterpart.
    """
    loaded = torch.load(pt_filename, map_location="cpu", weights_only=True)
    if "state_dict" in loaded:
        loaded = loaded["state_dict"]

    metadata = {"format": "pt"}
    # Remove tied tensors from the loaded state_dict
    for kept_name, to_remove_group in to_removes.items():
        for to_remove in to_remove_group:
            # Record the removed tensor name in the metadata; needed when
            # reading the safetensors file back
            if to_remove not in metadata:
                metadata[to_remove] = kept_name
            del loaded[to_remove]

    # Force tensors to be contiguous
    loaded = {k: v.contiguous() for k, v in loaded.items()}

    # Write the safetensors version to disk
    save_file(loaded, sf_filename, metadata=metadata)

    # Check that every remaining tensor survived the round trip
    reloaded = load_file(sf_filename)
    for k in loaded:
        pt_tensor = loaded[k]
        sf_tensor = reloaded[k]
        if not torch.equal(pt_tensor, sf_tensor):
            raise RuntimeError(f"The output tensors do not match for key {k}")
    print(f"Safetensors version saved to {sf_filename} and passed check")
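
# A minimal sketch of the reverse direction (an assumed helper; nothing in
# this script calls it): how a direct reader of the .safetensors file could
# use the metadata written by `convert_file` to restore the tied entries.
# The name `retie_state_dict` is illustrative, not part of the original flow.
def retie_state_dict(sf_filename: str) -> dict[str, torch.Tensor]:
    from safetensors import safe_open

    state_dict = load_file(sf_filename)
    with safe_open(sf_filename, framework="pt", device="cpu") as f:
        metadata = f.metadata() or {}
    for removed_name, kept_name in metadata.items():
        # Skip the {"format": "pt"} entry; every other entry maps a removed
        # (tied) tensor name to the name of the tensor that was kept
        if removed_name == "format":
            continue
        # Point the removed key back at the kept tensor, re-tying the weights
        state_dict[removed_name] = state_dict[kept_name]
    return state_dict
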
def validate(
    prompt: str,
    processor,
    model_pt: torch.nn.Module,
    model_sf: torch.nn.Module,
    device,
    torch_dtype,
):
    """Run the same prompt through both models and assert that the parsed
    answers match."""
    url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg?download=true"
    image = Image.open(requests.get(url, stream=True).raw)
    inputs = processor(text=prompt, images=image, return_tensors="pt").to(
        device, torch_dtype
    )
    parsed_answers = []
    for model in [model_pt, model_sf]:
        generated_ids = model.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=1024,
            do_sample=False,
            num_beams=3,
        )
        generated_text = processor.batch_decode(
            generated_ids, skip_special_tokens=False
        )[0]
        parsed_answer = processor.post_process_generation(
            generated_text, task=prompt, image_size=(image.width, image.height)
        )
        parsed_answers.append(parsed_answer)

    assert (
        parsed_answers[0] == parsed_answers[1]
    ), f"pt gave {parsed_answers[0]} but sf gave {parsed_answers[1]}"
    print(f"{prompt} successfully matched between pt & sf versions")
    print(parsed_answers[0])


def main():
    pt_filename = "pytorch_model.bin"
    sf_filename = "model.safetensors"
    # Key is the tensor to keep (shared.weight); the value is the list of
    # tensors in the _tied_weights listed in
    # Florence2LanguageForConditionalGeneration
    to_removes = {
        "language_model.model.shared.weight": [
            "language_model.model.encoder.embed_tokens.weight",
            "language_model.model.decoder.embed_tokens.weight",
            "language_model.lm_head.weight",
        ]
    }
    convert_file(
        pt_filename=pt_filename, sf_filename=sf_filename, to_removes=to_removes
    )

    # Validate on a test image for a few different prompts
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    processor = AutoProcessor.from_pretrained(
        "microsoft/Florence-2-large-ft", trust_remote_code=True
    )
    print("\nLoading .bin file")
    model_pt = AutoModelForCausalLM.from_pretrained(
        "microsoft/Florence-2-large-ft",
        torch_dtype=torch_dtype,
        trust_remote_code=True,
    ).to(device)
    print("\nLoading safetensors file")
    model_sf = AutoModelForCausalLM.from_pretrained(
        "./",
        use_safetensors=True,
        torch_dtype=torch_dtype,
        trust_remote_code=True,
    ).to(device)

    # Representative Florence-2 task prompts (angle-bracket task tokens)
    for prompt in ["<OD>", "<CAPTION>", "<DETAILED_CAPTION>"]:
        validate(prompt, processor, model_pt, model_sf, device, torch_dtype)
        print("")


if __name__ == "__main__":
    main()
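
# Optional sanity check, sketched here for illustration: the helper name
# `check_weight_tying` is assumed and nothing above calls it. After
# `from_pretrained` loads the safetensors file, the tensors removed during
# conversion should again share storage with the kept `shared.weight`.
def check_weight_tying(model: torch.nn.Module) -> None:
    shared = model.language_model.model.shared.weight
    tied = [
        model.language_model.model.encoder.embed_tokens.weight,
        model.language_model.model.decoder.embed_tokens.weight,
        model.language_model.lm_head.weight,
    ]
    for tensor in tied:
        # Tied parameters point at the same underlying storage
        assert tensor.data_ptr() == shared.data_ptr(), "weights are not tied"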