{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "Untitled0.ipynb",
      "version": "0.3.2",
      "provenance": [],
      "collapsed_sections": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "metadata": {
        "id": "UBpbr4JZKYTz",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "# MNIST Digit Classification Using Recurrent Neural Networks\n",
        "\n",
        "Each 28x28 MNIST image is read row by row, i.e. as a sequence of 28 steps with 28 features per step. A single `BasicRNNCell` layer processes the sequence and a dense layer maps its final state to 10 class logits. The notebook uses the TensorFlow 1.x graph API (`tf.placeholder`, `tf.Session`).\n",
        "\n",
        "**Recorded result:** an earlier run of this notebook (10 epochs, batch size 128, learning rate 0.001, seed 111) reported a test loss of 0.124 and a test accuracy of 0.965. That version printed the training loss/accuracy of the last mini-batch of each epoch; the training cell below reports the mean over all mini-batches of the epoch instead."
      ]
    },
    {
      "metadata": {
        "colab_type": "code"
      },
      "cell_type": "code",
      "source": [
        "import argparse\n",
        "\n",
        "import matplotlib.pyplot as plt\n",
        "import numpy as np\n",
        "import tensorflow as tf\n",
        "from tensorflow.examples.tutorials.mnist import input_data"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "metadata": {
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "## Configuration\n",
        "\n",
        "Everything a reader might want to tune lives in the next cell."
      ]
    },
    {
      "metadata": {
        "colab_type": "code"
      },
      "cell_type": "code",
      "source": [
        "######################\n",
        "# Optimization Flags #\n",
        "######################\n",
        "learning_rate = 0.001  # initial learning rate passed to Adam\n",
        "seed = 111  # seed shared by TensorFlow and NumPy\n",
        "\n",
        "##################\n",
        "# Training Flags #\n",
        "##################\n",
        "batch_size = 128  # number of examples per mini-batch\n",
        "num_epoch = 10  # number of passes over the training set\n",
        "\n",
        "###############\n",
        "# Model Flags #\n",
        "###############\n",
        "hidden_size = 128  # number of units in the RNN hidden layer\n",
        "\n",
        "# Divide 28x28 images into rows that are fed to the RNN as sequential information\n",
        "step_size = 28  # sequence length: one step per image row\n",
        "input_size = 28  # features per step: pixels in one row\n",
        "output_size = 10  # number of digit classes"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "metadata": {
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "## Model\n",
        "\n",
        "The graph is reset and re-seeded at the top of the cell so that re-running it builds one fresh model rather than adding a second copy of every op to the existing graph."
      ]
    },
    {
      "metadata": {
        "colab_type": "code"
      },
      "cell_type": "code",
      "source": [
        "# Reset the graph and make the random numbers repeatable using \"seed\"\n",
        "tf.reset_default_graph()\n",
        "tf.set_random_seed(seed)\n",
        "np.random.seed(seed)\n",
        "\n",
        "# Input tensors: a batch of row sequences and one integer label per image\n",
        "X = tf.placeholder(tf.float32, [None, step_size, input_size])\n",
        "y = tf.placeholder(tf.int32, [None])\n",
        "\n",
        "# RNN; only the final state is used for classification\n",
        "cell = tf.nn.rnn_cell.BasicRNNCell(num_units=hidden_size)\n",
        "output, state = tf.nn.dynamic_rnn(cell, X, dtype=tf.float32)\n",
        "\n",
        "# Forward pass and loss calculation\n",
        "logits = tf.layers.dense(state, output_size)\n",
        "cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(\n",
        "    labels=y, logits=logits)\n",
        "loss = tf.reduce_mean(cross_entropy)\n",
        "\n",
        "# Optimizer\n",
        "optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(loss)\n",
        "\n",
        "# Prediction: an example counts as correct when its label is the top-1 logit\n",
        "prediction = tf.nn.in_top_k(logits, y, 1)\n",
        "accuracy = tf.reduce_mean(tf.cast(prediction, tf.float32))\n",
        "\n",
        "# Op that initializes the variables (created last so it covers all of them)\n",
        "init = tf.global_variables_initializer()"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "metadata": {
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "## Load data\n",
        "\n",
        "`read_data_sets` downloads the four MNIST archives into `MNIST_data/` on first use. **NOTE:** TensorFlow marks this loader as deprecated and prints several deprecation warnings when it runs."
      ]
    },
    {
      "metadata": {
        "colab_type": "code"
      },
      "cell_type": "code",
      "source": [
        "mnist = input_data.read_data_sets(\"MNIST_data/\")\n",
        "\n",
        "# Test images arrive flattened as [num_test, 28*28]; reshape them into row sequences\n",
        "X_test = mnist.test.images.reshape([-1, step_size, input_size])\n",
        "y_test = mnist.test.labels"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "metadata": {
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "## Train and evaluate\n",
        "\n",
        "The loss and accuracy of every mini-batch are fetched in the same `sess.run` call as the optimizer step and averaged per epoch, so the printed training metrics describe the whole epoch rather than a single batch of 128 images. The test set is evaluated once, after the last epoch."
      ]
    },
    {
      "metadata": {
        "colab_type": "code"
      },
      "cell_type": "code",
      "source": [
        "# Per-epoch training metrics (mean over all mini-batches of the epoch)\n",
        "loss_train_list = []\n",
        "acc_train_list = []\n",
        "\n",
        "# Train the model\n",
        "with tf.Session() as sess:\n",
        "    sess.run(init)\n",
        "    n_batches = mnist.train.num_examples // batch_size\n",
        "    for epoch in range(num_epoch):\n",
        "        epoch_losses = []\n",
        "        epoch_accs = []\n",
        "        for batch in range(n_batches):\n",
        "            X_train, y_train = mnist.train.next_batch(batch_size)\n",
        "            X_train = X_train.reshape([-1, step_size, input_size])\n",
        "            # Fetch the metrics together with the update step so that no\n",
        "            # separate forward pass over the batch is needed\n",
        "            _, loss_batch, acc_batch = sess.run(\n",
        "                [optimizer, loss, accuracy],\n",
        "                feed_dict={X: X_train, y: y_train})\n",
        "            epoch_losses.append(loss_batch)\n",
        "            epoch_accs.append(acc_batch)\n",
        "        loss_train = np.mean(epoch_losses)\n",
        "        acc_train = np.mean(epoch_accs)\n",
        "        loss_train_list.append(loss_train)\n",
        "        acc_train_list.append(acc_train)\n",
        "        print('Epoch: {}, Train Loss: {:.3f}, Train Acc: {:.3f}'.format(\n",
        "            epoch + 1, loss_train, acc_train))\n",
        "    # Evaluate on the full test set while the trained variables are still alive\n",
        "    loss_test, acc_test = sess.run(\n",
        "        [loss, accuracy], feed_dict={X: X_test, y: y_test})\n",
        "    print('Test Loss: {:.3f}, Test Acc: {:.3f}'.format(loss_test, acc_test))"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "metadata": {
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "## Training curves"
      ]
    },
    {
      "metadata": {
        "colab_type": "code"
      },
      "cell_type": "code",
      "source": [
        "epochs = np.arange(1, len(loss_train_list) + 1)\n",
        "\n",
        "fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(10, 4))\n",
        "ax_loss.plot(epochs, loss_train_list, marker='o')\n",
        "ax_loss.set(title='Training loss', xlabel='Epoch', ylabel='Mean cross-entropy')\n",
        "ax_acc.plot(epochs, acc_train_list, marker='o')\n",
        "ax_acc.set(title='Training accuracy', xlabel='Epoch', ylabel='Mean accuracy')\n",
        "fig.tight_layout()\n",
        "plt.show()"
      ],
      "execution_count": null,
      "outputs": []
    }
  ]
}