colt12 commited on
Commit
08de8c9
·
verified ·
1 Parent(s): 5a8c48c

Create pytorch_model.bin

Browse files
Files changed (1) hide show
  1. pytorch_model.bin +54 -0
pytorch_model.bin ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ python
2
+ import os
3
+ import torch
4
+ from safetensors import safe_open
5
+ from safetensors.torch import save_file
6
+ import logging
7
+ import shutil
8
+ from datetime import datetime
9
+
10
+ # Set up logging
11
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
12
+
13
+ def backup_existing_file(file_path):
14
+ if os.path.exists(file_path):
15
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
16
+ backup_path = f"{file_path}.backup_{timestamp}"
17
+ shutil.copy2(file_path, backup_path)
18
+ logging.info(f"Created backup of existing file: {backup_path}")
19
+
20
+ def convert_safetensors_to_pytorch(input_file, output_file):
21
+ try:
22
+ # Check if input file exists
23
+ if not os.path.exists(input_file):
24
+ raise FileNotFoundError(f"Input file {input_file} not found.")
25
+
26
+ # Backup existing pytorch_model.bin if it exists
27
+ backup_existing_file(output_file)
28
+
29
+ # Load the safetensors file
30
+ logging.info(f"Loading safetensors file: {input_file}")
31
+ with safe_open(input_file, framework="pt", device="cpu") as f:
32
+ state_dict = {key: f.get_tensor(key) for key in f.keys()}
33
+
34
+ # Save as PyTorch bin file
35
+ logging.info(f"Saving as PyTorch bin file: {output_file}")
36
+ torch.save(state_dict, output_file)
37
+
38
+ logging.info("Conversion complete.")
39
+ logging.info(f"Created: {output_file}")
40
+
41
+ except Exception as e:
42
+ logging.error(f"An error occurred during conversion: {str(e)}")
43
+ raise
44
+
45
+ if __name__ == "__main__":
46
+ input_file = "maxcushion.safetensors"
47
+ output_file = "pytorch_model.bin"
48
+
49
+ try:
50
+ convert_safetensors_to_pytorch(input_file, output_file)
51
+ except Exception as e:
52
+ logging.error(f"Conversion failed: {str(e)}")
53
+ else:
54
+ logging.info("Script executed successfully.")