|
|
|
|
|
import numpy as np
|
|
import struct
|
|
import matplotlib.pyplot as plt
|
|
from array import array
|
|
|
|
def load_image(path):
    '''
    Read an IDX image file and return its pixels as a normalized NumPy array.

    The 16-byte big-endian header (magic, image count, rows, cols) is parsed, the
    remaining bytes are interpreted as unsigned 8-bit pixels, a trailing
    single-channel axis is added and the values are divided by 255 so they fall
    in [0 - 1].
    ---------------------------------------------------------------------------------------
    arguments:
        path: (str) IDX file path

    return:
        reshape_img: float np array of size (m x rows x cols x 1), where m, rows and
                     cols come from the file header (m x 28 x 28 x 1 for 28 x 28 images)

    raises:
        ValueError: if the magic number is not the IDX unsigned-byte image magic, or
                    the amount of pixel data does not match the header
        struct.error: if the file is shorter than the 16-byte header
    ---------------------------------------------------------------------------------------
    '''
    # IDX magic: two zero bytes, 0x08 (unsigned byte data), 0x03 (three dimensions).
    idx_image_magic = 0x00000803

    with open(path, 'rb') as img_data:
        magic, num, rows, cols = struct.unpack(">IIII", img_data.read(16))
        if magic != idx_image_magic:
            raise ValueError(
                f"{path!r} is not an IDX image file "
                f"(magic {magic}, expected {idx_image_magic})"
            )
        pixels = np.frombuffer(img_data.read(), dtype=np.uint8)

    # Use the geometry declared in the header instead of assuming 28 x 28; reshape
    # raises ValueError when the payload size disagrees with num * rows * cols.
    images = pixels.reshape(num, rows, cols, 1)

    # True division yields a new float array scaled into [0, 1].
    reshape_img = images / 255
    return reshape_img
|
|
|
|
|
|
def load_label(path):
    '''
    Read an IDX label file and return the labels one-hot encoded as a NumPy array.

    The 8-byte big-endian header (magic, label count) is parsed and checked, the
    remaining bytes are read as unsigned 8-bit class indices, and each index
    selects a row of a 10 x 10 identity matrix.
    ----------------------------------------------------------------------------------------
    arguments:
        path: (str) IDX file path

    return:
        label: integer np array of size (m, 10), one row per example, with a single 1
               in the column of the example's class [0 - 9]

    raises:
        ValueError: if the magic number is not the IDX unsigned-byte label magic, or
                    the number of labels does not match the header
        IndexError: if a label value is outside [0 - 9]
        struct.error: if the file is shorter than the 8-byte header
    ---------------------------------------------------------------------------------------
    '''
    # IDX magic: two zero bytes, 0x08 (unsigned byte data), 0x01 (one dimension).
    idx_label_magic = 0x00000801
    num_classes = 10

    with open(path, 'rb') as lb_data:
        magic, num = struct.unpack(">II", lb_data.read(8))
        if magic != idx_label_magic:
            raise ValueError(
                f"{path!r} is not an IDX label file "
                f"(magic {magic}, expected {idx_label_magic})"
            )
        labels = np.frombuffer(lb_data.read(), dtype=np.uint8)

    # Without this check a truncated file would silently produce too few rows.
    if labels.size != num:
        raise ValueError(
            f"{path!r} declares {num} labels but contains {labels.size}"
        )

    # Row k of the identity matrix is the one-hot vector for class k.
    label = np.eye(num_classes, dtype = int)[labels]
    return label
|
|
|
|
def image_show(data, label, index):
    '''
    Show a single example as a grayscale image with its label in the title.

    A new 5 x 5 figure is opened, data[index] is drawn with the 'gray' colormap,
    label[index] is interpolated as-is into the title, the axes are hidden and
    the figure is shown.
    ---------------------------------------------------------------------------------------
    arguments:
        data: (features (ndarray)) train / test images to be displayed
        label: (array) corresponding labels, either predictions or the true labels
        index: (int) position of the example to be displayed

    return:
        None; the plot is displayed as a side effect
    ---------------------------------------------------------------------------------------
    '''
    # The figure is created first so the call order matches the pyplot state machine.
    figure_size = (5, 5)
    plt.figure(figsize = figure_size)

    picture = data[index]
    plt.imshow(picture, cmap = 'gray')

    caption = f'It is {label[index]}'
    plt.title(caption)

    # Tick marks carry no information for an image preview.
    plt.axis('off')
    plt.show()