Spaces:
Running
on
Zero
Running
on
Zero
fix
Browse files- scripts/anime.py +1 -1
- scripts/data.py +1 -1
- scripts/model.py +1 -4
scripts/anime.py
CHANGED
@@ -19,7 +19,7 @@ model = None
|
|
19 |
def init_model(use_local=False):
|
20 |
global model
|
21 |
model_opt = "default"
|
22 |
-
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
23 |
model = create_model(model_opt, use_local).to(device)
|
24 |
model.eval()
|
25 |
|
|
|
19 |
def init_model(use_local=False):
|
20 |
global model
|
21 |
model_opt = "default"
|
22 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # issue: nevetherless, use_gpu is False, it still uses GPU
|
23 |
model = create_model(model_opt, use_local).to(device)
|
24 |
model.eval()
|
25 |
|
scripts/data.py
CHANGED
@@ -40,7 +40,7 @@ def get_transform(load_size=0, grayscale=False, method=bic, convert=True):
|
|
40 |
transform_list.append(transforms.Grayscale(1))
|
41 |
if load_size > 0:
|
42 |
osize = [load_size, load_size]
|
43 |
-
transform_list.append(transforms.Resize(osize, method))
|
44 |
if convert:
|
45 |
# transform_list += [transforms.ToTensor()]
|
46 |
if grayscale:
|
|
|
40 |
transform_list.append(transforms.Grayscale(1))
|
41 |
if load_size > 0:
|
42 |
osize = [load_size, load_size]
|
43 |
+
transform_list.append(transforms.Resize(osize, method, antialias=False))
|
44 |
if convert:
|
45 |
# transform_list += [transforms.ToTensor()]
|
46 |
if grayscale:
|
scripts/model.py
CHANGED
@@ -154,8 +154,7 @@ def create_model(model, use_local):
|
|
154 |
|
155 |
import os
|
156 |
if model == 'default':
|
157 |
-
model_path = (lambda filename, subfolder: os.path.join(subfolder, filename) if use_local else download_file(filename, subfolder))
|
158 |
-
("netG.pth", "models/Anime2Sketch")
|
159 |
# model_path = ((filename, subfolder) => if (use_local) os.path.join(subfolder, filename) else download_file(filename, subfolder))("netG.pth", "models/Anime2Sketch") // JavaScript
|
160 |
|
161 |
ckpt = torch.load(model_path)
|
@@ -176,8 +175,6 @@ def create_model(model, use_local):
|
|
176 |
base = base.model[3]
|
177 |
|
178 |
net.load_state_dict(ckpt)
|
179 |
-
|
180 |
-
os.chdir(cwd) # 元のディレクトリに戻る
|
181 |
|
182 |
else:
|
183 |
raise ValueError(f"model should be one of ['default', 'improved'], but got {model}")
|
|
|
154 |
|
155 |
import os
|
156 |
if model == 'default':
|
157 |
+
model_path = (lambda filename, subfolder: os.path.join(subfolder, filename) if use_local else download_file(filename, subfolder))("netG.pth", "models/Anime2Sketch")
|
|
|
158 |
# model_path = ((filename, subfolder) => if (use_local) os.path.join(subfolder, filename) else download_file(filename, subfolder))("netG.pth", "models/Anime2Sketch") // JavaScript
|
159 |
|
160 |
ckpt = torch.load(model_path)
|
|
|
175 |
base = base.model[3]
|
176 |
|
177 |
net.load_state_dict(ckpt)
|
|
|
|
|
178 |
|
179 |
else:
|
180 |
raise ValueError(f"model should be one of ['default', 'improved'], but got {model}")
|