本文最后更新于:14 天前
预处理数据：
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time : 2018/9/6 15:31
# @Author : Seven
# @Site :
# @File : Read_data.py
# @Software: PyCharm
# TODO: 加载数据
import pickle
import numpy as np
from sklearn.preprocessing import MinMaxScaler, LabelBinarizer
def load_cifar10_batch(path, batch_id):
    """Load a single CIFAR-10 training batch file.

    Opens ``<path>/data_batch_<batch_id>``, unpickles it, and converts the
    flat ``'data'`` rows into channel-last images.

    :param path: directory that contains the ``data_batch_*`` files
    :param batch_id: number appended to the ``data_batch_`` file prefix
    :return: ``(features, labels)``; ``features`` has shape
        ``(N, 32, 32, 3)`` and ``labels`` is the pickled ``'labels'`` entry,
        returned unchanged
    """
    batch_file = path + '/data_batch_' + str(batch_id)

    # NOTE: pickle.load can execute arbitrary code -- only use this on trusted
    # dataset files. encoding='latin1' allows reading pickles written by
    # Python 2 (presumably how these batch files were produced).
    with open(batch_file, mode='rb') as fh:
        raw = pickle.load(fh, encoding='latin1')

    pixels = raw['data']
    # Each row is split into three 32x32 channel planes, then the channel
    # axis is moved last to obtain (N, H, W, C) images.
    planar = pixels.reshape((len(pixels), 3, 32, 32))
    features = planar.transpose(0, 2, 3, 1)
    return features, raw['labels']
# Data preprocessing
def pre_processing_data(x_train, y_train, x_test, y_test):
    """Min-max scale the images and one-hot encode the labels.

    Images are flattened to one row per sample and scaled per feature by a
    ``MinMaxScaler`` fitted on the *training* images only. The same fitted
    scaler is then applied to the test images, so both splits share one
    scaling and no statistics are taken from the test set. As a consequence,
    test values may fall slightly outside [0, 1]. Each split is reshaped back
    to its original shape afterwards.

    :param x_train: training images, e.g. shape ``(50000, 32, 32, 3)``
    :param y_train: training labels, expected to be class ids 0-9
    :param x_test: test images with the same per-sample shape as ``x_train``
    :param y_test: test labels, expected to be class ids 0-9
    :return: ``(x_train, y_train, x_test, y_test)`` with scaled images and
        one-hot encoded labels (10 columns)
    """
    scaler = MinMaxScaler()

    # Flatten each image into one feature row, e.g.
    # (50000, 32, 32, 3) -> (50000, 32*32*3); MinMaxScaler expects 2-D input.
    x_train_rows = x_train.reshape(len(x_train), -1)
    x_test_rows = x_test.reshape(len(x_test), -1)

    # Fit on training data only, then reuse those statistics for the test
    # data. Re-fitting on the test set would scale the two splits
    # inconsistently and leak test-set statistics into preprocessing.
    x_train_norm = scaler.fit_transform(x_train_rows)
    x_test_norm = scaler.transform(x_test_rows)

    # Restore each split's original image shape.
    x_train = x_train_norm.reshape(x_train.shape)
    x_test = x_test_norm.reshape(x_test.shape)

    # One-hot encode against a fixed set of 10 classes so both splits get the
    # same columns even if a split happens to lack some class.
    n_class = 10
    label_binarizer = LabelBinarizer().fit(np.arange(n_class))
    y_train = label_binarizer.transform(y_train)
    y_test = label_binarizer.transform(y_test)
    return x_train, y_train, x_test, y_test
def cifar10_data(cifar10_path='data'):
    """Load CIFAR-10 from disk and return the preprocessed train/test splits.

    Reads the five training files ``data_batch_1`` .. ``data_batch_5`` via
    ``load_cifar10_batch`` and the ``test_batch`` file, then passes all four
    arrays through ``pre_processing_data``.

    :param cifar10_path: directory holding the CIFAR-10 batch files; defaults
        to ``'data'``, the previously hard-coded location
    :return: ``(x_train, y_train, x_test, y_test)`` as returned by
        ``pre_processing_data``
    """
    # The training set is split across five batch files. Collect them all and
    # concatenate once, instead of re-copying a growing array on every step.
    train_batches = [load_cifar10_batch(cifar10_path, batch_id)
                     for batch_id in range(1, 6)]
    x_train = np.concatenate([features for features, _ in train_batches])
    y_train = np.concatenate([labels for _, labels in train_batches])

    # The test set lives in a single file with the same pickled layout.
    # NOTE: pickle.load can execute arbitrary code -- only load trusted files.
    with open(cifar10_path + '/test_batch', mode='rb') as fh:
        batch = pickle.load(fh, encoding='latin1')
    # Split each row into three 32x32 channel planes, then move channels last.
    x_test = batch['data'].reshape((len(batch['data']), 3, 32, 32)).transpose(0, 2, 3, 1)
    y_test = batch['labels']

    return pre_processing_data(x_train, y_train, x_test, y_test)
配置文件：
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time : 2018/9/6 16:02
# @Author : Seven
# @Site :
# @File : config.py
# @Software: PyCharm
import tensorflow as tf
import matplotlib.pyplot as plt
# Hyper-parameters for the convolutional network.
keep_prob = 0.8  # presumably the dropout keep probability; usage not visible here
epochs = 20
batch_size = 128
n_classes = 10  # 10 classes in total
learning_rate = 0.001

# Placeholders for the input images and their labels.
inputs = tf.placeholder(tf.float32, [None, 32, 32, 3], name='inputs')
# The label width follows n_classes instead of repeating the literal 10.
# NOTE(review): this label placeholder is registered in the graph as 'logits'
# even though it holds targets. It is left unchanged because other code may
# fetch it by that name (e.g. 'logits:0') -- confirm before renaming.
targets = tf.placeholder(tf.float32, [None, n_classes], name='logits')
# Display images
def show_images(images, nrows=3, ncols=3):
    """Draw the first ``nrows * ncols`` images on a grid with axes hidden.

    Images are placed row by row, in the order they appear in ``images``.
    Grid cells left over when fewer images are supplied stay blank. The
    figure is not displayed here; call ``plt.show()`` (or save the returned
    figure) yourself.

    :param images: indexable sequence of images accepted by ``Axes.imshow``
    :param nrows: number of grid rows (default 3, matching the old fixed grid)
    :param ncols: number of grid columns (default 3, matching the old fixed grid)
    :return: the created matplotlib ``Figure``
    """
    # squeeze=False keeps ``axes`` 2-D even when nrows or ncols is 1.
    fig, axes = plt.subplots(nrows=nrows, ncols=ncols, sharex=True, sharey=True,
                             figsize=(9, 9), squeeze=False)
    # Fill the grid with consecutive images. Slicing the input into chunks of
    # 20 and zipping them with 3-wide rows would draw only images 0-2, 20-22
    # and 40-42, silently dropping the rest.
    for index, ax in enumerate(axes.ravel()):
        if index < len(images):
            ax.imshow(images[index])
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
    fig.tight_layout(pad=0.1)
    return fig
# All AlexNet weight tensors, keyed by layer name. The variables are created
# in the order listed below.
# NOTE(review): tf.random_normal defaults to mean 0.0 / stddev 1.0. That is
# very large for these fan-ins (e.g. 11*11*3 for wc1, 4096 for wd2) and is
# likely to make activations blow up; consider a smaller stddev or
# Xavier/He initialisation. Kept as-is here to preserve current behaviour.
weights = {
    layer: tf.Variable(tf.random_normal(shape=shape))
    for layer, shape in (
        # Conv filters, presumably [kernel_h, kernel_w, in_channels,
        # out_channels] for tf.nn.conv2d -- confirm against the model code.
        ('wc1', [11, 11, 3, 96]),
        ('wc2', [5, 5, 96, 256]),
        ('wc3', [3, 3, 256, 384]),
        ('wc4', [3, 3, 384, 384]),
        ('wc5', [3, 3, 384, 256]),
        # Dense layers. 4*4*256 assumes the conv stack leaves a 4x4x256
        # feature map to flatten -- confirm against the model code.
        ('wd1', [4 * 4 * 256, 4096]),
        ('wd2', [4096, 1024]),
        ('out', [1024, n_classes]),
    )
}
# One bias vector per layer, created in the order listed below. Each length
# equals the output width of the like-named weight tensor
# (wc1 -> bc1, wd1 -> bd1, out -> out).
biases = {
    key: tf.Variable(tf.random_normal([width]))
    for key, width in (
        ('bc1', 96),
        ('bc2', 256),
        ('bc3', 384),
        ('bc4', 384),
        ('bc5', 256),
        ('bd1', 4096),
        ('bd2', 1024),
        ('out', n_classes),
    )
}
本博客所有文章除特别声明外，均采用 CC BY-SA 3.0 协议。转载请注明出处！