kjerk committed
Commit d4615af
1 Parent(s): ce883f8

Fix embedding reparser regression

Files changed (1):
  1. tools/torch_tools.py +18 -5
tools/torch_tools.py CHANGED
@@ -16,16 +16,29 @@ def get_target_dtype_ref(target_dtype: str) -> torch.dtype:
     raise ValueError(f"Invalid target_dtype: {target_dtype}")
 
 def convert_ckpt_to_safetensors(ckpt_upload: io.BytesIO, target_dtype) -> dict:
+    if isinstance(ckpt_upload, bytes):
+        ckpt_upload = io.BytesIO(ckpt_upload)
+
     target_dtype = get_target_dtype_ref(target_dtype)
-    ckpt_data = ckpt_upload.getvalue()
 
     # Load the checkpoint
-    checkpoint = torch.load(ckpt_data, map_location="cpu")
+    loaded_dict = torch.load(ckpt_upload, map_location="cpu")
 
-    # Convert the checkpoint to a dictionary of tensors
     tensor_dict = {}
-    for key, val in checkpoint.items():
-        tensor_dict[key] = val.to(dtype=target_dtype)
+
+    is_embedding = 'string_to_param' in loaded_dict
+    if is_embedding:
+        emb_tensor = loaded_dict.get('string_to_param', {}).get('*', None)
+        if emb_tensor is not None:
+            emb_tensor = emb_tensor.to(dtype=target_dtype)
+        tensor_dict = {
+            'emb_params': emb_tensor
+        }
+    else:
+        # Convert weights in a checkpoint to a dictionary of tensors
+        for key, val in loaded_dict.items():
+            if isinstance(val, torch.Tensor):
+                tensor_dict[key] = val.to(dtype=target_dtype)
 
     return tensor_dict
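
For context, a minimal sketch of the two payload shapes the patched function now distinguishes: a textual-inversion embedding nests its weights under 'string_to_param' → '*', while a plain checkpoint is a flat dict of tensors. The sample data, output file names, and the 'fp16' dtype string below are assumptions for illustration (the accepted strings are whatever get_target_dtype_ref defines), and the import path assumes tools/torch_tools.py is importable as in this repo.

import io

import torch
from safetensors.torch import save_file

from tools.torch_tools import convert_ckpt_to_safetensors

# Hypothetical textual-inversion embedding: weights live under
# checkpoint['string_to_param']['*'] rather than as top-level tensors.
embedding_ckpt = {'string_to_param': {'*': torch.randn(2, 768)}}

# Hypothetical plain checkpoint: a flat dict of tensors.
plain_ckpt = {'layer.weight': torch.randn(4, 4), 'layer.bias': torch.randn(4)}

for name, state in [('embedding', embedding_ckpt), ('plain', plain_ckpt)]:
    # Serialize the way an uploaded .pt/.ckpt arrives, then rewind the buffer.
    buffer = io.BytesIO()
    torch.save(state, buffer)
    buffer.seek(0)

    tensor_dict = convert_ckpt_to_safetensors(buffer, 'fp16')
    print(name, {k: tuple(v.shape) for k, v in tensor_dict.items()})
    # embedding -> {'emb_params': (2, 768)}   (reparsed to the flat 'emb_params' key)
    # plain     -> {'layer.weight': (4, 4), 'layer.bias': (4,)}

    # The converted dict can then be written out as a .safetensors file.
    save_file(tensor_dict, f"{name}.safetensors")

Before this fix, the embedding case fell through the generic loop: 'string_to_param' holds a nested dict rather than a tensor, so the embedding weights were dropped instead of being reparsed to 'emb_params'.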