博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
Tensorflow catdog-checkpoint
阅读量:7114 次
发布时间:2019-06-28

本文共 5188 字,大约阅读时间需要 17 分钟。

import tensorflow as tfimport globimport numpy as np复制代码
/anaconda3/envs/py35/lib/python3.5/importlib/_bootstrap.py:222: RuntimeWarning: compiletime version 3.6 of module 'tensorflow.python.framework.fast_tensor_util' does not match runtime version 3.5  return f(*args, **kwds)/anaconda3/envs/py35/lib/python3.5/site-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.  from ._conv import register_converters as _register_converters复制代码
image_filenames = glob.glob('./dc/train/*.jpg')复制代码
image_filenames[0]复制代码
---------------------------------------------------------------------------IndexError                                Traceback (most recent call last)
in
()----> 1 image_filenames[0]IndexError: list index out of range复制代码
lables = list(map(lambda x : x.split('\\')[1].split('.')[0], image_filenames))复制代码
train_lable = [[1, 0] if x == 'cat' else [0, 1] for x in lables]复制代码
image_que = tf.train.slice_input_producer([image_filenames, train_lable])复制代码
image_ = tf.read_file(image_que[0])image = tf.image.decode_jpeg(image_, channels = 3)复制代码
grey_image = tf.image.rgb_to_grayscale(image)resize_image = tf.image_resize_images(grey_image, (200, 200))resize_image = tf.reshape(resize_image, [200, 200, 1])复制代码
new_img = tf.image.per_image_standardization(resize_image)复制代码
batch_size = 60capacity = 10 + 2 * batch_size复制代码
image_batch, lable_batch = tf.train.batch([new_img, image_que[1]], batch_size=batch_size, capacity=capacity)复制代码
lable_batch.get_shape()复制代码
conv2d_1 = tf.contrib.layers.convolution2d(    image_batch,    num_outputs=32,    weights_initializer = tf.truncated_normal_initializer(stddev=0.001),    kernel_size = (5,5),    activation_fn = tf.nn.relu,    stride = (1,1),    padding = 'SAME',    trainable = True)pool_1 = tf.nn.max_pool(conv2d_1,                       ksize = [1,3,3,1],                       strides = [1,2,2,1],                       padding='SAME')复制代码
conv2d_2 = tf.contrib.layers.convolution2d(    pool_1,    num_outputs=32,    weights_initializer = tf.truncated_normal_initializer(stddev=0.01),    kernel_size = (5,5),    activation_fn = tf.nn.relu,    stride = (1,1),    padding = 'SAME',    trainable = True)pool_2 = tf.nn.max_pool(conv2d_2,                       ksize = [1,3,3,1],                       strides = [1,2,2,1],                       padding='SAME')复制代码
conv2d_3 = tf.contrib.layers.convolution2d(    pool_2,    num_outputs=64,    weights_initializer = tf.truncated_normal_initializer(stddev=0.01),    kernel_size = (5,5),    activation_fn = tf.nn.relu,    stride = (1,1),    padding = 'SAME',    trainable = True)pool_3 = tf.nn.max_pool(conv2d_3,                       ksize = [1,3,3,1],                       strides = [1,2,2,1],                       padding='SAME')复制代码
pool_3.get_shape()复制代码
pool3_flat = tf.reshape(pool_3, [-1, 25*25*64])fc_1 = tf.contrib.layers.fully_connected(                            pool3_flat,                             1024,                             weights_initializer = tf.truncated_normal_initializer(stddev=0.1),                            activation_fn = tf.nn.relu)复制代码
fc_2 = tf.contrib.layers.fully_connected(                            fc_1,                             192,                             weights_initializer = tf.truncated_normal_initializer(stddev=0.1),                            activation_fn = tf.nn.relu)复制代码
out_wl = tf.Variable(tf.truncated_normal([192, 2]))out_bl = tf.Variable(tf.truncated_normal([2]))comb_out = tf.matmul(fc_2, out_wl) + out_blpred = tf.sigmoid(comb_out)复制代码
pred.get_shape()复制代码
lable_batch.get_shape()复制代码
lable_batch = tf.cast(lable_batch, tf.float32)复制代码
loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels = lable_batch, logits = comb_out))复制代码
train_step = tf.train.AdamOptimizer(0.0001).minimize(loss)复制代码
predicted = tf.cast(pred >0.5, tf.float32)accuracy = tf.reduce_mean(tf.cast(tf.equal(predicted, lable_batch), tf.float32))复制代码
saver = tf.train.Saver()with tf.Session() as sess:    coord = tf.train.Coordinator()    threads = tf.train.start_queue_runners(coord = coord)    sess.run(tf.global_variables_initializer())    for step in range(0, 3000):        sess.run(train_step)        if(step %100 == 0):            res = sess.run([loss, accuracy])            print(step, res)            saver.save(sess, './lesson30', global_step = step)    coord.request_stop()    coord.join(threads)复制代码

恢复检查点

#import os#ckpt = tf.train_get_checkpoint_state(os.path.dirname('__file__'))#saver = tf.train.Saver()#sess = tf.Session()#sess.run(tf.global_variables_initializer())#saver.restore(sess, ckpt.model_checkpoint_path)复制代码
#coord = tf.train.Coordinator()#threads = tf.train.start_queue_runners(coord = coord)#for step in range(6000, 7000):#        sess.run(train_step, feed_dict={keep_prob:0.5})#        if(step %100 == 0):#            res = sess.run([loss, accuracy], feed_dict={keep_prob:1})#            print(step, res)#            saver.save(sess, './lesson30', global_step = step)#coord.request_stop()#coord.join(threads)复制代码

转载地址:http://zczel.baihongyu.com/

你可能感兴趣的文章
vim配置
查看>>
[Swift]UIKit学习之滑块控件UISlider的用法
查看>>
我的友情链接
查看>>
nginx+tomcat+memcached构建session共享集群
查看>>
回看Java环境变量classpath
查看>>
mysql数据库Explain详解 .
查看>>
python 多线程插入mysql
查看>>
数据库索引学习相关资料汇总
查看>>
equals和hashcode详解
查看>>
简单使用jumpserver
查看>>
利用碎片时间,TURBOMAIL飞邮手机客户端助你抓住每一个机遇
查看>>
execute、executeQuery和executeUpdate之间的区别
查看>>
equals()和hashCode()区别?
查看>>
开篇可好
查看>>
出现故障任何信息还是要亲自确认
查看>>
iOS开发-自定义控件的方式及注意
查看>>
Heartbeat + drbd 实现对mysql服务的高可用
查看>>
趣味学习:一篇文章读懂三层交换机【新任帮主】
查看>>
Hadoop源代码分析 - MapReduce(转载)
查看>>
C#摄像头实现拍照功能的简单代码示例
查看>>