ixxan commited on
Commit
d07e27a
1 Parent(s): 8f0494a

from pre trained

Browse files
Files changed (1) hide show
  1. UModel.py +29 -0
UModel.py CHANGED
@@ -153,6 +153,35 @@ class UModel(nn.Module):
153
 
154
  predstrs = [uyghur_latin.decode(pred) for pred in preds]
155
  return predstrs
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
156
 
157
 
158
  if __name__ == "__main__":
 
153
 
154
  predstrs = [uyghur_latin.decode(pred) for pred in preds]
155
  return predstrs
156
+
157
+ @classmethod
158
+ def from_pretrained(cls, repo_id: str, filename: str = "best_model.pth", device: str = "cpu", **kwargs):
159
+ """
160
+ Load a pretrained UModel from a Hugging Face repository.
161
+
162
+ Args:
163
+ repo_id (str): The Hugging Face repository ID.
164
+ filename (str): The name of the checkpoint file in the repo.
165
+ device (str): The device to load the model onto ("cpu" or "cuda").
166
+ **kwargs: Additional arguments to pass to the UModel constructor.
167
+
168
+ Returns:
169
+ UModel: An instance of UModel with pretrained weights loaded.
170
+ """
171
+ # Download the checkpoint from the Hugging Face Hub
172
+ checkpoint_path = hf_hub_download(repo_id=repo_id, filename=filename)
173
+
174
+ # Initialize the model
175
+ model = cls(**kwargs)
176
+
177
+ # Load state dict
178
+ checkpoint = torch.load(checkpoint_path, map_location=device)
179
+ model.load_state_dict(checkpoint["st_dict"])
180
+
181
+ # Set model to evaluation mode
182
+ model.eval()
183
+ print(f"Loaded model from {repo_id}/{filename} onto {device}")
184
+ return model
185
 
186
 
187
  if __name__ == "__main__":