"""Plotting helpers that dump samples and predictions to PNG files.

Every public function writes one ``f'{prefix}_{i}.png'`` per batch element ``i``.
"""
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

sns.set()


def _to_uint8(a):
    """Clip ``a`` to [0, 255] and cast to uint8.

    NumPy does not define float->uint8 casts of out-of-range values (in
    practice they usually wrap around and show up as speckle); clipping
    saturates them at black/white instead.
    """
    return np.clip(a, 0, 255).astype(np.uint8)


def _kspace_magnitude(k):
    """Turn a (..., H, W, 2) real/imag k-space array into a (..., H, W, 1) magnitude image.

    Applies ``ifftshift`` and ``ifft2`` over the last two spatial axes and
    keeps the absolute value.
    """
    spectrum = k[..., 0] + k[..., 1] * 1j
    image = np.fft.ifft2(np.fft.ifftshift(spectrum, axes=(-2, -1)))
    return np.expand_dims(np.absolute(image), axis=-1)


def _tile_frames(a):
    """Place (N, H, W, C) frames side by side as (H, N*W, C), then drop unit axes."""
    n, h, w, c = a.shape
    return np.transpose(a, [1, 0, 2, 3]).reshape(h, w * n, c).squeeze()


def _tile_flat_images(a, k, size):
    """Reshape ``k`` flattened ``size x size`` images into one (size, k*size) uint8 strip."""
    frames = _to_uint8(np.reshape(a, [k, size, size]))
    return np.reshape(np.transpose(frames, [1, 0, 2]), [size, k * size])


def visualize(s, batch, prefix):
    """Dispatch to the matching visualizer based on the rank of ``s``.

    Args:
        s: samples. Rank 5 -> ``im_visualize``, rank 3 -> ``pc_visualize``,
            rank 4 -> ``fn_visualize``.
        batch: dict of arrays. Needs keys ``'x'``, ``'b'``, ``'m'`` for ranks
            5 and 3, and ``'xc'``, ``'yc'``, ``'xt'``, ``'yt'`` for rank 4.
        prefix: output path prefix; files are written as ``{prefix}_{i}.png``.

    Raises:
        ValueError: if ``s`` has any other rank.
    """
    ndim = len(s.shape)
    if ndim == 5:
        im_visualize(s, batch['x'], batch['b'], batch['m'], prefix)
    elif ndim == 3:
        pc_visualize(s, batch['x'], batch['b'], batch['m'], prefix)
    elif ndim == 4:
        fn_visualize(s, batch['xc'], batch['yc'], batch['xt'], batch['yt'], prefix)
    else:
        raise ValueError(
            f'unsupported sample rank {ndim} (shape {s.shape}); expected 3, 4 or 5')


def im_visualize(s, x, b, m, prefix):
    """Save a three-row comparison image for every batch element.

    Args:
        s: samples, shape (B, N, H, W, C).
        x: reference images with the same layout as ``s``.
        b: mask with the same layout; where it is 0 the pixel is drawn grey
            (128) in the second row, where it is 1 the pixel of ``x`` is shown
            (presumably a 0/1 mask -- TODO confirm).
        m: mask applied the same way for the first row.
        prefix: output path prefix.

    When the last axis of ``s`` has size 2 it is treated as (real, imag)
    k-space: ``s`` and ``x`` are inverse-FFT'd into single-channel magnitude
    images scaled by 255, and only channel 0 of the masks is kept.

    Rows of the saved image, top to bottom: ``x`` masked by ``m``, ``x``
    masked by ``b``, then ``s``; the N frames are tiled left to right.
    """
    for i in range(len(s)):
        ss, xx, bb, mm = s[i], x[i], b[i], m[i]
        if ss.shape[-1] == 2:
            # k-space input. Nothing bounds the reconstructed magnitude to
            # [0, 1], so saturate rather than let the uint8 cast wrap around.
            ss = _to_uint8(_kspace_magnitude(ss) * 255)
            xx = _to_uint8(_kspace_magnitude(xx) * 255)
            bb = bb[..., 0:1]
            mm = mm[..., 0:1]
        ss, xx, bb, mm = [_tile_frames(a) for a in (ss, xx, bb, mm)]
        # Grey (128) wherever the mask is 0.
        xm = xx * mm + (1 - mm) * 128
        xo = xx * bb + (1 - bb) * 128
        img = _to_uint8(np.concatenate([xm, xo, ss]))
        # NOTE(review): single-channel images are 2-D after squeeze, so
        # plt.imsave applies its default colormap and scales to the data
        # range; pass cmap='gray', vmin=0, vmax=255 if true grey is wanted.
        plt.imsave(f'{prefix}_{i}.png', img)


def pc_visualize(s, x, b, m, prefix):
    """Save three side-by-side 3-D scatter plots for every batch element.

    Args:
        s: sampled point clouds, shape (B, N, C); columns 0-2 are plotted as
            the three coordinates.
        x: reference point clouds with the same layout.
        b: per-point mask; points of ``x`` with ``b[:, 0] == 1`` form the
            middle panel.
        m: unused.
        prefix: output path prefix.

    Panels, left to right: all of ``x``, the points of ``x`` selected by
    ``b``, and ``s``.
    """
    batch_size, _, _ = s.shape
    for i in range(batch_size):
        ss, xx, bb = s[i], x[i], b[i]
        observed = np.where(bb[:, 0] == 1)[0]
        fig = plt.figure(figsize=(7.5, 2.5))
        for position, pts in zip((131, 132, 133), (xx, xx[observed], ss)):
            ax = fig.add_subplot(position, projection='3d')
            ax.scatter(pts[:, 0], pts[:, 1], pts[:, 2], c='g', s=5)
            ax.axis('off')
            ax.grid(False)
        fig.savefig(f'{prefix}_{i}.png')
        # Close only this figure; plt.close('all') would also close the caller's.
        plt.close(fig)


def fn_visualize(s, xc, yc, xt, yt, prefix):
    """Save one column of K subplots for every batch element.

    Args:
        s: sampled outputs at the target inputs, shape (B, K, N, C).
        xc, yc: context inputs/outputs, indexed as ``[batch, k]``.
        xt, yt: target inputs/outputs, indexed as ``[batch, k]``.
        prefix: output path prefix.

    Each subplot shows the context points (red crosses), the targets (black
    dots) and the samples plotted against ``xt`` (blue dots).
    """
    batch_size, K, _, _ = s.shape
    for i in range(batch_size):
        ss, xxc, yyc, xxt, yyt = s[i], xc[i], yc[i], xt[i], yt[i]
        fig = plt.figure(figsize=(4.0, 2.5 * K))
        for k in range(K):
            ax = fig.add_subplot(K, 1, k + 1)
            ax.plot(xxc[k], yyc[k], 'rx', markersize=8)
            ax.plot(xxt[k], yyt[k], 'ko', markersize=3)
            ax.plot(xxt[k], ss[k], 'bo', markersize=3)
        fig.savefig(f'{prefix}_{i}.png')
        plt.close(fig)


def plot_functions(m, s, batch, prefix):
    """Save predicted curves with a +/- band for every batch element.

    Args:
        m: predicted values at the target inputs, shape (B, K, N, C); only
            channel 0 is plotted.
        s: band half-width with the same shape as ``m`` (presumably a
            standard deviation -- TODO confirm); only channel 0 is used.
        batch: dict with ``'xc'``, ``'yc'``, ``'xt'``, ``'yt'`` arrays of
            shape (B, K, *, C); only channel 0 is used.
        prefix: output path prefix.

    Each of the K subplots shows the context points (red crosses), the
    targets (black dots), ``m`` as a blue line sorted by target input, and
    the band ``m - s`` .. ``m + s``.
    """
    batch_size, K, _, _ = m.shape
    xc, yc, xt, yt = batch['xc'], batch['yc'], batch['xt'], batch['yt']
    for i in range(batch_size):
        mm, ss = m[i, :, :, 0], s[i, :, :, 0]
        xxc, yyc = xc[i, :, :, 0], yc[i, :, :, 0]
        xxt, yyt = xt[i, :, :, 0], yt[i, :, :, 0]
        fig = plt.figure(figsize=(4.0, 2.5 * K))
        for k in range(K):
            # Sort by input so the mean line and band are drawn left to right.
            order = np.argsort(xxt[k])
            x_sorted, mean, spread = xxt[k, order], mm[k, order], ss[k, order]
            ax = fig.add_subplot(K, 1, k + 1)
            ax.plot(xxc[k], yyc[k], 'rx', markersize=8)
            ax.plot(xxt[k], yyt[k], 'ko', markersize=3)
            ax.plot(x_sorted, mean, 'b', linewidth=2)
            ax.fill_between(
                x_sorted, mean - spread, mean + spread,
                alpha=0.2, facecolor='#65c9f7', interpolate=True)
        fig.savefig(f'{prefix}_{i}.png')
        plt.close(fig)


def plot_img_functions(m, s, batch, prefix, img_size=28):
    """Save a three-row image strip (context, prediction, target) per batch element.

    Args:
        m: predictions, shape (B, K, N, C); mapped to pixels via
            ``(v + 0.5) * 255`` and clipped to [0, 255].
        s: unused.
        batch: dict with ``'idx'`` (indices of the context pixels along axis 2),
            ``'yc'`` (context values) and ``'yt'`` (target values, same shape
            as ``m``); values are mapped to pixels like ``m``.
        prefix: output path prefix.
        img_size: side length of the square images; each of the K slices must
            hold ``img_size * img_size`` values. Defaults to 28.

    Unobserved pixels in the context row are drawn as grey (128).
    """
    _, K, _, _ = m.shape
    idx, yc, yt = batch['idx'], batch['yc'], batch['yt']
    yo = np.ones_like(yt) * 128
    yo[:, :, idx] = (yc + 0.5) * 255.
    yt = (yt + 0.5) * 255.
    m = (m + 0.5) * 255.
    for i in range(len(m)):
        rows = [_tile_flat_images(a, K, img_size) for a in (yo[i], m[i], yt[i])]
        img = np.concatenate(rows, axis=0)
        # NOTE(review): 2-D input -> plt.imsave uses its default colormap and
        # data-range scaling; pass cmap='gray', vmin=0, vmax=255 for true grey.
        plt.imsave(f'{prefix}_{i}.png', img)