File size: 2,266 Bytes
ec24a39 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 |
# project utilities
import numpy as np
import struct
import matplotlib.pyplot as plt
from array import array
def load_image(path):
    '''
    Read an IDX image file and return the images as a normalized NumPy array.

    The 16-byte header is unpacked as four big-endian unsigned 32-bit integers
    (magic number, number of images, rows, columns). Every remaining byte is
    read as one unsigned 8-bit pixel. The image dimensions come from the
    header rather than being hard-coded, so a 28 x 28 file still yields
    (m x 28 x 28 x 1) while files with other image sizes are handled too.
    ---------------------------------------------------------------------------------------
    arguments:
    path: (str) IDX file path
    return:
    reshape_img: float np array of size (m x rows x cols x 1), values in [0 - 1]
    raises:
    struct.error: if the file holds fewer than 16 header bytes
    ValueError: if the amount of pixel data does not match the header dimensions
    ---------------------------------------------------------------------------------------
    '''
    with open(path, 'rb') as img_data:
        # The magic number is consumed to advance past it; it is not validated.
        _magic, num, rows, cols = struct.unpack(">IIII", img_data.read(16))
        pixels = np.frombuffer(img_data.read(), dtype=np.uint8)
    # Add a trailing axis of size 1 and scale the 0-255 byte values to 0-1.
    # Dividing also produces a new float array, so the result is writable
    # even though np.frombuffer returns a read-only view.
    reshape_img = pixels.reshape(num, rows, cols, 1) / 255
    return reshape_img
def load_label(path, num_classes=10):
    '''
    Read an IDX label file and return the labels one-hot encoded.

    The 8-byte header is unpacked as two big-endian unsigned 32-bit integers
    (magic number, number of labels). Every remaining byte is read as one
    unsigned 8-bit class index and turned into a one-hot row.
    ----------------------------------------------------------------------------------------
    arguments:
    path: (str) IDX file path
    num_classes: (int) number of classes, i.e. the width of each one-hot row;
                 defaults to 10 for classes [0 - 9]
    return:
    label: int np array of size (m, num_classes)
    raises:
    struct.error: if the file holds fewer than 8 header bytes
    IndexError: if a label value is >= num_classes
    ---------------------------------------------------------------------------------------
    '''
    with open(path, 'rb') as lb_data:
        # The header is consumed to advance past it; neither field is used,
        # so the number of rows returned follows the bytes actually present.
        _magic, _num = struct.unpack(">II", lb_data.read(8))
        labels = np.frombuffer(lb_data.read(), dtype=np.uint8)
    # Row i of the identity matrix is the one-hot vector for class i, so
    # indexing it with the label array selects one row per example.
    label = np.eye(num_classes, dtype=int)[labels]
    return label
def image_show(data, label, index):
    '''
    Display a single example together with its label.

    A new 5 x 5 inch figure is opened, data[index] is drawn with the gray
    colormap, label[index] is formatted into the title, the axes are hidden
    and the figure is shown.
    ---------------------------------------------------------------------------------------
    arguments:
    data: (features (ndarray)) train / test images to be displayed
    label: (array) corresponding labels, either predicted or ground truth
    index: (int) position of the example to be displayed
    return:
    None; the figure is displayed through plt.show()
    ---------------------------------------------------------------------------------------
    '''
    plt.figure(figsize=(5, 5))

    # Draw the selected example.
    picture = data[index]
    plt.imshow(picture, cmap='gray')

    # Caption the figure with the matching label entry.
    caption = f'It is {label[index]}'
    plt.title(caption)

    plt.axis('off')
    plt.show()