RNN, Recurrent Neural Networks 进行分类(classification),采用 MNIST 数据集,用 SimpleRNN 层。
1 | import numpy as np |
1. data pre-processing
MNIST里面的图像分辨率是28×28,为用RNN,将图像理解为序列化数据。
每一行作为一个输入单元,所以输入数据大小 INPUT_SIZE = 28
;
先是第1行输入,再是第2行,…,第28行输入, 这就是一张图片也就是一个序列,所以步长 TIME_STEPS = 28
。
训练数据要进行 normalize,因为原始数据是 8bit 灰度图像, 所以需要除以 255。
1 | # download the mnist to the path '~/.keras/datasets/' if it is the first time to be called |
1 | print(X_train.shape) |
2. build model
1 | # build RNN model |
1 | # output layer |
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
simple_rnn_1 (SimpleRNN) (None, 50) 3950
_________________________________________________________________
dense_1 (Dense) (None, 10) 510
_________________________________________________________________
activation_1 (Activation) (None, 10) 0
_________________________________________________________________
dense_2 (Dense) (None, 10) 110
_________________________________________________________________
activation_2 (Activation) (None, 10) 0
=================================================================
Total params: 4,570
Trainable params: 4,570
Non-trainable params: 0
_________________________________________________________________
设置优化方法,loss函数 和 metrics
方法之后就可以开始训练了。 每次训练的时候并不是取所有的数据,只是取 BATCH_SIZE
个序列,或者称为 BATCH_SIZE
张图片,这样可以大大降低运算时间,提高训练效率。
3. training & evaluate
输出 test 上的 loss
和 accuracy
结果
1 | # training |
test cost: 2.311124086380005 test accuracy: 0.0957999974489212
test cost: 1.6327736377716064 test accuracy: 0.5228999853134155
test cost: 1.3161704540252686 test accuracy: 0.559499979019165
test cost: 1.1487971544265747 test accuracy: 0.5494999885559082
test cost: 1.0471760034561157 test accuracy: 0.5713000297546387
test cost: 1.0110148191452026 test accuracy: 0.5630999803543091
test cost: 0.9520753622055054 test accuracy: 0.5877000093460083
test cost: 0.8796814680099487 test accuracy: 0.604200005531311
test cost: 0.858435869216919 test accuracy: 0.6585999727249146
Checking if Disqus is accessible...