2019/6/16 Python入門②サンプルプログラム

1.MNISTデータセットの内容確認

まずはじめにMNISTデータセットの内部構造を確認してみます。

訓練用データ 6万件
テスト用データ 1万件

横27pixel 縦27pixel の256階調グレースケールの画像データと、 正解ラベル

からなります。

from tensorflow import keras
from matplotlib import pyplot as plt

mnist = keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

print("訓練データの形式: ", train_images.shape)
print("テストデータの形式: ", test_images.shape)

for i in range(0,25):
  print("ラベル[",i,"]:", train_labels[i])
  plt.subplot(5, 5, i+1)
  plt.imshow(train_images[i],'gray')

2.手書き数字認識用ニューラルネットワーク

from __future__ import absolute_import, division, print_function

import tensorflow as tf
from tensorflow import keras

import numpy as np
import matplotlib.pyplot as plt
import time

# TensorFlow のバージョンの出力
print(tf.__version__)

# MNISTデータのロード
mnist = keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

class_names = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']

# 前処理
# 0~255階調のグレースケールのイメージを0~1にしてニューラルネットワークに最適化
train_images = train_images / 255.0
test_images = test_images / 255.0


#
# モデルの作成
#

# レイヤーのセットアップ
# 入力層: 28x28 = 784pixels
# 隠れ層:128ノード、活性化関数はReLU
# 出力層:10ノード、ソフトマックス関数(出力の和は1)
model = keras.Sequential([
    keras.layers.Flatten(input_shape=(28, 28)),
    keras.layers.Dense(128, activation=tf.nn.relu),
    keras.layers.Dense(10, activation=tf.nn.softmax)
])

# モデルのコンパイル(訓練過程の設定)
# Loss(損失)関数:スパースなNクラス分類交差エントロピー
# Optimizer(最適化)関数:adam(Adaptive Moment Estimation)
# Metrics:トレーニングとテストの手順を監視するために使われる。(認識)精度で表す(正しく分類された画像の割合)
model.compile(optimizer='adam', 
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'])

# モデルの訓練
# 訓練用画像データ、訓練用画像データの正解ラベル、訓練データの総数1エポックとして5エポックの訓練を行う
model.fit(train_images, train_labels, epochs=5)

# 精度の評価
# テストデータにより精度がどう変わるかを評価する
test_loss, test_acc = model.evaluate(test_images, test_labels)
# 訓練データによる過適合により多少精度は下がる

# テストデータによる推論処理
predictions = model.predict(test_images)

TensorFlowチュートリアル

はじめてのニューラルネットワーク:分類問題の初歩

コメントを残す

メールアドレスが公開されることはありません。 が付いている欄は必須項目です