Spaces:
Runtime error
Runtime error
Update gen.py
Browse files
gen.py
CHANGED
@@ -63,7 +63,7 @@ def get_pretrained_models(
|
|
63 |
|
64 |
llama_ckpt_path = checkpoints[local_rank]
|
65 |
print("Loading")
|
66 |
-
checkpoint = torch.load(llama_ckpt_path, map_location=
|
67 |
with open(Path(llama_weight_path) / "params.json", "r") as f:
|
68 |
params = json.loads(f.read())
|
69 |
|
|
|
63 |
|
64 |
llama_ckpt_path = checkpoints[local_rank]
|
65 |
print("Loading")
|
66 |
+
checkpoint = torch.load(llama_ckpt_path, map_location="cpu")
|
67 |
with open(Path(llama_weight_path) / "params.json", "r") as f:
|
68 |
params = json.loads(f.read())
|
69 |
|