『TensorFlow入门』DNNClassifier

04-03   2  391

MNIST数据集处理是机器学习新手绕不开的问题,但TensorFlow官方给出了一个更为简单易懂的例子——Iris品种分类问题。

以下代码已详细注释,作为学习记录,若有纰漏,还请指正!


main.py:

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import tensorflow as tf
import iris_data

# Command-line options accepted by the script; both are integers.
parser = argparse.ArgumentParser()
parser.add_argument(
    '--batch_size',
    type=int,
    default=100,
    help='batch size',
)
parser.add_argument(
    '--train_steps',
    type=int,
    default=1000,
    help='number of training steps',
)


def main(argv):
    """Train, evaluate and query a DNN classifier on the Iris data set.

    Args:
        argv: Command-line argument list. Element 0 is skipped and the
            remainder is parsed by the module-level ``parser``.
    """
    options = parser.parse_args(argv[1:])

    # Fetch the training and test splits.
    (train_x, train_y), (test_x, test_y) = iris_data.load_data()

    # One numeric feature column per key of the training features.
    my_feature_columns = [
        tf.feature_column.numeric_column(key=name) for name in train_x.keys()
    ]

    # DNN with three hidden layers of 100 units each and 3 output classes.
    classifier = tf.estimator.DNNClassifier(
        feature_columns=my_feature_columns,
        hidden_units=[100, 100, 100],
        n_classes=3)

    def train_inputs():
        return iris_data.train_input_fn(train_x, train_y, options.batch_size)

    def test_inputs():
        return iris_data.eval_input_fn(test_x, test_y, options.batch_size)

    # Train the model for the requested number of steps.
    classifier.train(input_fn=train_inputs, steps=options.train_steps)
    print('训练完毕...')

    # Evaluate the model on the test split.
    eval_result = classifier.evaluate(input_fn=test_inputs)
    print('\n评估完毕...')
    print('测试集准确率: {accuracy:0.4f}\n'.format(**eval_result))

    # Three unlabeled samples together with the species expected for each.
    expected = ['Setosa', 'Versicolor', 'Virginica']
    predict_x = {
        'SepalLength': [5.1, 5.9, 6.9],
        'SepalWidth': [3.3, 3.0, 3.1],
        'PetalLength': [1.7, 4.2, 5.4],
        'PetalWidth': [0.5, 1.5, 2.1],
    }

    def predict_inputs():
        return iris_data.eval_input_fn(predict_x, labels=None,
                                       batch_size=options.batch_size)

    predictions = classifier.predict(input_fn=predict_inputs)

    template = '预测为:"{}" ({:.4f}%), 期望为:"{}"'

    # Report the predicted species and its probability next to the expectation.
    for outcome, wanted in zip(predictions, expected):
        class_id = outcome['class_ids'][0]
        confidence = outcome['probabilities'][class_id]
        species = iris_data.SPECIES[class_id]
        print(template.format(species, 100 * confidence, wanted))


if __name__ == '__main__':
    # Show INFO-level TensorFlow log messages while the script runs.
    tf.logging.set_verbosity(tf.logging.INFO)
    # Hand control to TensorFlow's app runner, which invokes main().
    tf.app.run(main)


iris_data.py:

import pandas as pd
import tensorflow as tf

# Remote locations of the Iris training and test CSV files.
TRAIN_URL = "http://download.tensorflow.org/data/iris_training.csv"
TEST_URL = "http://download.tensorflow.org/data/iris_test.csv"

# Column names applied to the CSV data when it is read by load_data();
# 'Species' is load_data()'s default label column.
CSV_COLUMN_NAMES = ['SepalLength', 'SepalWidth',
                    'PetalLength', 'PetalWidth', 'Species']
# Species names; main.py indexes this list with the predicted class id.
SPECIES = ['Setosa', 'Versicolor', 'Virginica']


def maybe_download():
    """Return local paths of the training and test CSV files.

    Each file is obtained through ``tf.keras.utils.get_file``, which
    downloads it when it is not already available locally. The local file
    name is the last path segment of the URL.

    Returns:
        A ``(train_path, test_path)`` tuple.
    """
    local_paths = []
    for url in (TRAIN_URL, TEST_URL):
        file_name = url.split('/')[-1]
        local_paths.append(tf.keras.utils.get_file(file_name, url))

    train_path, test_path = local_paths
    return train_path, test_path


def load_data(y_name='Species'):
    """Return the Iris data set as (train_x, train_y), (test_x, test_y).

    Args:
        y_name: Name of the column to split off as the label.

    Returns:
        Two ``(features, labels)`` pairs, training first and test second.
        The label column is removed from the features frame.
    """
    train_path, test_path = maybe_download()

    def split_frame(csv_path):
        # header=0 together with names= replaces the file's own first row
        # with CSV_COLUMN_NAMES.
        frame = pd.read_csv(csv_path, names=CSV_COLUMN_NAMES, header=0)
        labels = frame.pop(y_name)
        return frame, labels

    training = split_frame(train_path)
    testing = split_frame(test_path)
    return training, testing


def train_input_fn(features, labels, batch_size):
    """Build the input dataset used for training.

    Args:
        features: Mapping-like object of feature name to values; it is
            converted with ``dict()``.
        labels: Labels aligned with the feature values.
        batch_size: Number of examples per batch.

    Returns:
        A ``tf.data.Dataset`` that is shuffled, repeated without a count
        and batched.
    """
    # Slice the (features, labels) pair into individual examples.
    examples = tf.data.Dataset.from_tensor_slices((dict(features), labels))

    # Shuffle with a buffer of 1000, repeat without a count, then batch.
    shuffled = examples.shuffle(1000)
    repeated = shuffled.repeat()
    return repeated.batch(batch_size)


def eval_input_fn(features, labels, batch_size):
    """Build the input dataset used for evaluation and prediction.

    Args:
        features: Mapping-like object of feature name to values; it is
            converted with ``dict()``.
        labels: Labels aligned with the feature values, or ``None`` when
            only features are available (prediction).
        batch_size: Number of examples per batch; must not be ``None``.

    Returns:
        A batched ``tf.data.Dataset`` of features, or of
        ``(features, labels)`` pairs when labels are given.

    Raises:
        ValueError: If ``batch_size`` is ``None``.
    """
    # Validate up front with a real exception: an ``assert`` is stripped when
    # Python runs with -O, and checking first avoids building a dataset that
    # is about to be rejected.
    if batch_size is None:
        raise ValueError("需指定一个值")

    features = dict(features)
    # Without labels the dataset carries the features alone.
    inputs = features if labels is None else (features, labels)

    # Slice the inputs into individual examples, then batch them.
    dataset = tf.data.Dataset.from_tensor_slices(inputs)
    return dataset.batch(batch_size)


评论 ( {{ comments.total }} )
{{ o.content }}
赞 {{ o.likes_count ? o.likes_count : '' }} 回复 {{ o.created_at }}
作者信息