更强的模型
Browse files- cyclegan.py +19 -3
- model_data/G_model_B2A_last_epoch_weights.pth +1 -1
cyclegan.py
CHANGED
@@ -19,6 +19,10 @@ class CYCLEGAN(object):
|
|
19 |
#-----------------------------------------------#
|
20 |
"input_shape" : [112, 112],
|
21 |
#-------------------------------#
|
|
|
|
|
|
|
|
|
22 |
# 是否使用Cuda
|
23 |
# 没有GPU可以设置成False
|
24 |
#-------------------------------#
|
@@ -64,9 +68,14 @@ class CYCLEGAN(object):
|
|
64 |
#---------------------------------------------------------#
|
65 |
image = cvtColor(image)
|
66 |
#---------------------------------------------------------#
|
|
|
|
|
|
|
|
|
|
|
67 |
# 添加上batch_size维度
|
68 |
#---------------------------------------------------------#
|
69 |
-
image_data = np.expand_dims(np.transpose(preprocess_input(np.array(
|
70 |
|
71 |
with torch.no_grad():
|
72 |
images = torch.from_numpy(image_data)
|
@@ -80,10 +89,17 @@ class CYCLEGAN(object):
|
|
80 |
#---------------------------------------------------#
|
81 |
# 转为numpy
|
82 |
#---------------------------------------------------#
|
83 |
-
pr = pr.permute(1, 2, 0).cpu().numpy()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
84 |
|
85 |
image = postprocess_output(pr)
|
86 |
-
image = np.clip(image, 0, 255)
|
87 |
image = Image.fromarray(np.uint8(image))
|
88 |
|
89 |
return image
|
|
|
19 |
#-----------------------------------------------#
|
20 |
"input_shape" : [112, 112],
|
21 |
#-------------------------------#
|
22 |
+
# 是否进行不失真的resize
|
23 |
+
#-------------------------------#
|
24 |
+
"letterbox_image" : True,
|
25 |
+
#-------------------------------#
|
26 |
# 是否使用Cuda
|
27 |
# 没有GPU可以设置成False
|
28 |
#-------------------------------#
|
|
|
68 |
#---------------------------------------------------------#
|
69 |
image = cvtColor(image)
|
70 |
#---------------------------------------------------------#
|
71 |
+
# 给图像增加灰条,实现不失真的resize
|
72 |
+
# 也可以直接resize进行识别
|
73 |
+
#---------------------------------------------------------#
|
74 |
+
image_data, nw, nh = resize_image(image, (self.input_shape[1],self.input_shape[0]), self.letterbox_image)
|
75 |
+
#---------------------------------------------------------#
|
76 |
# 添加上batch_size维度
|
77 |
#---------------------------------------------------------#
|
78 |
+
image_data = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, dtype='float32')), (2, 0, 1)), 0)
|
79 |
|
80 |
with torch.no_grad():
|
81 |
images = torch.from_numpy(image_data)
|
|
|
89 |
#---------------------------------------------------#
|
90 |
# 转为numpy
|
91 |
#---------------------------------------------------#
|
92 |
+
pr = pr.permute(1, 2, 0).cpu().numpy()
|
93 |
+
|
94 |
+
#--------------------------------------#
|
95 |
+
# 将灰条部分截取掉
|
96 |
+
#--------------------------------------#
|
97 |
+
if nw is not None:
|
98 |
+
pr = pr[int((self.input_shape[0] - nh) // 2) : int((self.input_shape[0] - nh) // 2 + nh), \
|
99 |
+
int((self.input_shape[1] - nw) // 2) : int((self.input_shape[1] - nw) // 2 + nw)]
|
100 |
+
|
101 |
|
102 |
image = postprocess_output(pr)
|
|
|
103 |
image = Image.fromarray(np.uint8(image))
|
104 |
|
105 |
return image
|
model_data/G_model_B2A_last_epoch_weights.pth
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
size 11888773
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:1815cd8f77471a8712b9a80b20da4cd7afe7aad2b32ad48cd205d1c370a65dc2
|
3 |
size 11888773
|