from pre trained
Browse files
UModel.py
CHANGED
@@ -153,6 +153,35 @@ class UModel(nn.Module):
|
|
153 |
|
154 |
predstrs = [uyghur_latin.decode(pred) for pred in preds]
|
155 |
return predstrs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
156 |
|
157 |
|
158 |
if __name__ == "__main__":
|
|
|
153 |
|
154 |
predstrs = [uyghur_latin.decode(pred) for pred in preds]
|
155 |
return predstrs
|
156 |
+
|
157 |
+
@classmethod
|
158 |
+
def from_pretrained(cls, repo_id: str, filename: str = "best_model.pth", device: str = "cpu", **kwargs):
|
159 |
+
"""
|
160 |
+
Load a pretrained UModel from a Hugging Face repository.
|
161 |
+
|
162 |
+
Args:
|
163 |
+
repo_id (str): The Hugging Face repository ID.
|
164 |
+
filename (str): The name of the checkpoint file in the repo.
|
165 |
+
device (str): The device to load the model onto ("cpu" or "cuda").
|
166 |
+
**kwargs: Additional arguments to pass to the UModel constructor.
|
167 |
+
|
168 |
+
Returns:
|
169 |
+
UModel: An instance of UModel with pretrained weights loaded.
|
170 |
+
"""
|
171 |
+
# Download the checkpoint from the Hugging Face Hub
|
172 |
+
checkpoint_path = hf_hub_download(repo_id=repo_id, filename=filename)
|
173 |
+
|
174 |
+
# Initialize the model
|
175 |
+
model = cls(**kwargs)
|
176 |
+
|
177 |
+
# Load state dict
|
178 |
+
checkpoint = torch.load(checkpoint_path, map_location=device)
|
179 |
+
model.load_state_dict(checkpoint["st_dict"])
|
180 |
+
|
181 |
+
# Set model to evaluation mode
|
182 |
+
model.eval()
|
183 |
+
print(f"Loaded model from {repo_id}/{filename} onto {device}")
|
184 |
+
return model
|
185 |
|
186 |
|
187 |
if __name__ == "__main__":
|