Spaces:
Sleeping
Sleeping
Stop downloading models
Browse files
app.py
CHANGED
@@ -5,7 +5,6 @@ import gradio as gr
|
|
5 |
import numpy as np
|
6 |
import os
|
7 |
import torch
|
8 |
-
import subprocess
|
9 |
import output
|
10 |
|
11 |
from rdkit import Chem
|
@@ -53,24 +52,15 @@ args = parser.parse_args()
|
|
53 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
54 |
print(f'Device: {device}')
|
55 |
os.makedirs("results", exist_ok=True)
|
56 |
-
os.makedirs("models", exist_ok=True)
|
57 |
|
58 |
size_gnn_path = 'models/geom_size_gnn.ckpt'
|
59 |
-
if not os.path.exists(size_gnn_path):
|
60 |
-
print('Downloading SizeGNN model...')
|
61 |
-
link = 'https://zenodo.org/record/7121300/files/geom_size_gnn.ckpt?download=1'
|
62 |
-
subprocess.run(f'wget {link} -O {size_gnn_path}', shell=True)
|
63 |
size_nn = SizeClassifier.load_from_checkpoint('models/geom_size_gnn.ckpt', map_location=device).eval().to(device)
|
64 |
print('Loaded SizeGNN model')
|
65 |
|
66 |
|
67 |
diffusion_models = {}
|
68 |
for model_name, metadata in MODELS_METADATA.items():
|
69 |
-
link = metadata['link']
|
70 |
diffusion_path = metadata['path']
|
71 |
-
if not os.path.exists(diffusion_path):
|
72 |
-
print(f'Downloading {model_name}...')
|
73 |
-
subprocess.run(f'wget {link} -O {diffusion_path}', shell=True)
|
74 |
diffusion_models[model_name] = DDPM.load_from_checkpoint(diffusion_path, map_location=device).eval().to(device)
|
75 |
print(f'Loaded model {model_name}')
|
76 |
|
|
|
5 |
import numpy as np
|
6 |
import os
|
7 |
import torch
|
|
|
8 |
import output
|
9 |
|
10 |
from rdkit import Chem
|
|
|
52 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
53 |
print(f'Device: {device}')
|
54 |
os.makedirs("results", exist_ok=True)
|
|
|
55 |
|
56 |
size_gnn_path = 'models/geom_size_gnn.ckpt'
|
|
|
|
|
|
|
|
|
57 |
size_nn = SizeClassifier.load_from_checkpoint('models/geom_size_gnn.ckpt', map_location=device).eval().to(device)
|
58 |
print('Loaded SizeGNN model')
|
59 |
|
60 |
|
61 |
diffusion_models = {}
|
62 |
for model_name, metadata in MODELS_METADATA.items():
|
|
|
63 |
diffusion_path = metadata['path']
|
|
|
|
|
|
|
64 |
diffusion_models[model_name] = DDPM.load_from_checkpoint(diffusion_path, map_location=device).eval().to(device)
|
65 |
print(f'Loaded model {model_name}')
|
66 |
|