MNIST是一个手写数字集合,该数据集来自美国国家标准与技术研究所, National Institute of Standards and Technology (NIST). 训练集 (training set) 由来自 250 个不同人手写的数字构成, 其中 50% 是高中学生, 50% 来自人口普查局 (the Census Bureau) 的工作人员. 测试集(test set) 也是同样比例的手写数字数据。
首先要打开pkl文件,需要用到py包_pickle,这里的open要用rb,因为是要以二进制的方式读取文件。(我这里的pkl文件还是个压缩包,所以使用gzip打开)
import _pickle as cPickle
import gzip
f = gzip.open("data\mnist\mnist.pkl.gz",'rb')
training_data, validation_data, test_data = cPickle.load(f, encoding='bytes')
##训练集、验证集、测试集分别有50000,10000,10000张
print(len(training_data[0]))
print(len(validation_data[0]))
print(len(test_data[0]))
print(len(training_data)) #包含图像和标签两个维度
print(training_data[0][0].shape)#50000张784*1的图像
print(training_data[1].shape) #50000个数字标签
输出结果:
#训练集、验证集、测试集分别有50000,10000,10000张
50000
10000
10000
2
(784,)
(50000,)
读取训练集的第一张图片看看。
img = training_data[0]
img = img.reshape(28,-1)
print(type(img))
plt.imshow(img)
plt.show()

将图像从行向量(1784)转换成列向量(7841),并将图像对应的数字也转换成numpy类型的列向量(10*1),数字对应的索引置1,其余位置则为0。
def vectorized_result(j):
e = np.zeros((10, 1))
e[j] = 1.0
return e
training_inputs = [np.reshape(x, (784, 1)) for x in training_data[0]]
training_results = [vectorized_result(y) for y in training_data[1]]
training_data = list(zip(training_inputs, training_results))
validation_inputs = [np.reshape(x, (784, 1)) for x in validation_data[0]]
validation_data = list(zip(validation_inputs, validation_data[1]))
test_inputs = [np.reshape(x, (784, 1)) for x in test_data[0]]
test_data = list(zip(test_inputs, test_data[1]))
完整代码:
import matplotlib.pyplot as plt
import numpy as np
import _pickle as cPickle
import gzip
def vectorized_result(j):
e = np.zeros((10, 1))
e[j] = 1.0
return e
f = gzip.open("data\mnist\mnist.pkl.gz",'rb')
training_data, validation_data, test_data = cPickle.load(f, encoding='bytes')
print(len(training_data[0]))
print(len(validation_data[0]))
print(len(test_data[0]))
print(len(training_data)) #包含图像和标签两个维度
print(training_data[0][0].shape)#50000张784*1的图像
print(training_data[1].shape) #50000个数字标签
training_inputs = [np.reshape(x, (784, 1)) for x in training_data[0]]
training_results = [vectorized_result(y) for y in training_data[1]]
training_data = list(zip(training_inputs, training_results))
validation_inputs = [np.reshape(x, (784, 1)) for x in validation_data[0]]
validation_data = list(zip(validation_inputs, validation_data[1]))
test_inputs = [np.reshape(x, (784, 1)) for x in test_data[0]]
test_data = list(zip(test_inputs, test_data[1]))
img = training_inputs[0]
img = img.reshape(28,-1)
print(type(img))
plt.imshow(img)
plt.show()