← 返回首页
手写数字数据集MINST的读取及预处理
发表时间:2024-06-23 10:40:30
手写数字数据集MINST的读取及预处理

MNIST是一个手写数字集合,该数据集来自美国国家标准与技术研究所, National Institute of Standards and Technology (NIST). 训练集 (training set) 由来自 250 个不同人手写的数字构成, 其中 50% 是高中学生, 50% 来自人口普查局 (the Census Bureau) 的工作人员. 测试集(test set) 也是同样比例的手写数字数据。

1.MNIST数据集简介

  1. 该数据集包含60,000个用于训练的示例和10,000个用于测试的示例。
  2. 数据集包含了0-9共10类手写数字图片,每张图片都做了尺寸归一化,都是28x28大小的灰度图。
  3. MNIST数据集包含四个部分: 训练集图像:train-images-idx3-ubyte.gz(9.9MB,包含60000个样本) 训练集标签:train-labels-idx1-ubyte.gz(29KB,包含60000个标签) 测试集图像:t10k-images-idx3-ubyte.gz(1.6MB,包含10000个样本) 测试集标签:t10k-labels-idx1-ubyte.gz(5KB,包含10000个标签)

首先要打开pkl文件,需要用到py包_pickle,这里的open要用rb,因为是要以二进制的方式读取文件。(我这里的pkl文件还是个压缩包,所以使用gzip打开)

import _pickle as cPickle
import gzip

f = gzip.open("data\mnist\mnist.pkl.gz",'rb')
training_data, validation_data, test_data = cPickle.load(f, encoding='bytes')

##训练集、验证集、测试集分别有50000,10000,10000张
print(len(training_data[0]))
print(len(validation_data[0]))
print(len(test_data[0]))

print(len(training_data)) #包含图像和标签两个维度
print(training_data[0][0].shape)#50000张784*1的图像
print(training_data[1].shape) #50000个数字标签

输出结果:

#训练集、验证集、测试集分别有50000,10000,10000张
50000
10000
10000
2
(784,)
(50000,)

2.数据的预处理

读取训练集的第一张图片看看。

img = training_data[0]
img = img.reshape(28,-1)
print(type(img))
plt.imshow(img)
plt.show()

将图像从行向量(1784)转换成列向量(7841),并将图像对应的数字也转换成numpy类型的列向量(10*1),数字对应的索引置1,其余位置则为0。

def vectorized_result(j):
    e = np.zeros((10, 1))
    e[j] = 1.0
    return e

training_inputs = [np.reshape(x, (784, 1)) for x in training_data[0]]
training_results = [vectorized_result(y) for y in training_data[1]]
training_data = list(zip(training_inputs, training_results))

validation_inputs = [np.reshape(x, (784, 1)) for x in validation_data[0]]
validation_data = list(zip(validation_inputs, validation_data[1]))

test_inputs = [np.reshape(x, (784, 1)) for x in test_data[0]]
test_data = list(zip(test_inputs, test_data[1]))

完整代码:

import matplotlib.pyplot as plt
import numpy as np
import _pickle as cPickle
import gzip


def vectorized_result(j):
    e = np.zeros((10, 1))
    e[j] = 1.0
    return e

f = gzip.open("data\mnist\mnist.pkl.gz",'rb')
training_data, validation_data, test_data = cPickle.load(f, encoding='bytes')

print(len(training_data[0]))
print(len(validation_data[0]))
print(len(test_data[0]))

print(len(training_data)) #包含图像和标签两个维度
print(training_data[0][0].shape)#50000张784*1的图像
print(training_data[1].shape) #50000个数字标签


training_inputs = [np.reshape(x, (784, 1)) for x in training_data[0]]
training_results = [vectorized_result(y) for y in training_data[1]]
training_data = list(zip(training_inputs, training_results))

validation_inputs = [np.reshape(x, (784, 1)) for x in validation_data[0]]
validation_data = list(zip(validation_inputs, validation_data[1]))

test_inputs = [np.reshape(x, (784, 1)) for x in test_data[0]]
test_data = list(zip(test_inputs, test_data[1]))


img = training_inputs[0]
img = img.reshape(28,-1)
print(type(img))
plt.imshow(img)
plt.show()