秋山翔 commited on
Commit
51a7dff
1 Parent(s): a3348bb

TEST: debug gradio not rendering

Browse files
Files changed (2) hide show
  1. app.py +28 -26
  2. main.py +0 -76
app.py CHANGED
@@ -16,36 +16,38 @@ STYLE = "shinkai_makoto"
16
  MODEL_PATH = "models"
17
  COLOUR_MODEL = "RGB"
18
 
19
- model = Transformer()
20
- model.load_state_dict(torch.load(os.path.join(MODEL_PATH, f"{STYLE}.pth")))
21
- model.eval()
22
 
23
- disable_gpu = torch.cuda.is_available()
24
 
25
 
26
  def inference(img):
27
- # load image
28
- input_image = img.convert(COLOUR_MODEL)
29
- input_image = np.asarray(input_image)
30
- # RGB -> BGR
31
- input_image = input_image[:, :, [2, 1, 0]]
32
- input_image = transforms.ToTensor()(input_image).unsqueeze(0)
33
- # preprocess, (-1, 1)
34
- input_image = -1 + 2 * input_image
35
-
36
- if disable_gpu:
37
- input_image = Variable(input_image).float()
38
- else:
39
- input_image = Variable(input_image).cuda()
40
-
41
- # forward
42
- output_image = model(input_image)
43
- output_image = output_image[0]
44
- # BGR -> RGB
45
- output_image = output_image[[2, 1, 0], :, :]
46
- output_image = output_image.data.cpu().float() * 0.5 + 0.5
47
-
48
- return output_image
 
 
49
 
50
 
51
  title = "AnimeBackgroundGAN"
 
16
  MODEL_PATH = "models"
17
  COLOUR_MODEL = "RGB"
18
 
19
+ # model = Transformer()
20
+ # model.load_state_dict(torch.load(os.path.join(MODEL_PATH, f"{STYLE}.pth")))
21
+ # model.eval()
22
 
23
+ # disable_gpu = torch.cuda.is_available()
24
 
25
 
26
  def inference(img):
27
+ # # load image
28
+ # input_image = img.convert(COLOUR_MODEL)
29
+ # input_image = np.asarray(input_image)
30
+ # # RGB -> BGR
31
+ # input_image = input_image[:, :, [2, 1, 0]]
32
+ # input_image = transforms.ToTensor()(input_image).unsqueeze(0)
33
+ # # preprocess, (-1, 1)
34
+ # input_image = -1 + 2 * input_image
35
+
36
+ # if disable_gpu:
37
+ # input_image = Variable(input_image).float()
38
+ # else:
39
+ # input_image = Variable(input_image).cuda()
40
+
41
+ # # forward
42
+ # output_image = model(input_image)
43
+ # output_image = output_image[0]
44
+ # # BGR -> RGB
45
+ # output_image = output_image[[2, 1, 0], :, :]
46
+ # output_image = output_image.data.cpu().float() * 0.5 + 0.5
47
+
48
+ # return output_image
49
+
50
+ return ""
51
 
52
 
53
  title = "AnimeBackgroundGAN"
main.py DELETED
@@ -1,76 +0,0 @@
1
- import torch
2
- import os
3
- import numpy as np
4
- import torchvision.utils as vutils
5
-
6
- from PIL import Image
7
- import torchvision.transforms as transforms
8
- from torch.autograd import Variable
9
-
10
- from network.Transformer import Transformer
11
- import argparse
12
-
13
- parser = argparse.ArgumentParser()
14
- parser.add_argument("--input_dir", default="test_img")
15
- parser.add_argument("--load_size", default=1280)
16
- parser.add_argument("--model_path", default="./pretrained_model")
17
- parser.add_argument("--style", default="Shinkai")
18
- parser.add_argument("--output_dir", default="test_output")
19
- parser.add_argument("--gpu", type=int, default=0)
20
-
21
- opt = parser.parse_args()
22
-
23
- valid_ext = [".jpg", ".png"]
24
-
25
- # setup
26
- if not os.path.exists(opt.input_dir):
27
- os.makedirs(opt.input_dir)
28
- if not os.path.exists(opt.output_dir):
29
- os.makedirs(opt.output_dir)
30
-
31
- # load pretrained model
32
- model = Transformer()
33
- model.load_state_dict(
34
- torch.load(os.path.join(opt.model_path, opt.style + "_net_G_float.pth"))
35
- )
36
- model.eval()
37
-
38
- disable_gpu = opt.gpu == -1 or not torch.cuda.is_available()
39
-
40
- if disable_gpu:
41
- print("CPU mode")
42
- model.float()
43
- else:
44
- print("GPU mode")
45
- model.cuda()
46
-
47
- for files in os.listdir(opt.input_dir):
48
- ext = os.path.splitext(files)[1]
49
- if ext not in valid_ext:
50
- continue
51
- # load image
52
- input_image = Image.open(os.path.join(opt.input_dir, files)).convert("RGB")
53
- input_image = np.asarray(input_image)
54
- # RGB -> BGR
55
- input_image = input_image[:, :, [2, 1, 0]]
56
- input_image = transforms.ToTensor()(input_image).unsqueeze(0)
57
- # preprocess, (-1, 1)
58
- input_image = -1 + 2 * input_image
59
- if disable_gpu:
60
- input_image = Variable(input_image).float()
61
- else:
62
- input_image = Variable(input_image).cuda()
63
-
64
- # forward
65
- output_image = model(input_image)
66
- output_image = output_image[0]
67
- # BGR -> RGB
68
- output_image = output_image[[2, 1, 0], :, :]
69
- output_image = output_image.data.cpu().float() * 0.5 + 0.5
70
- # save
71
- vutils.save_image(
72
- output_image,
73
- os.path.join(opt.output_dir, files[:-4] + "_" + opt.style + ".jpg"),
74
- )
75
-
76
- print("Done!")