请选择 进入手机版 | 继续访问电脑版

大蛇智能

 找回密码
 立即注册

扫一扫,访问微社区

搜索
热搜: 活动 交友 discuz
查看: 3791|回复: 7

连载一:用slim调用PNASNet模型(内附源码)

[复制链接]

127

主题

308

帖子

989

积分

管理员

Rank: 9Rank: 9Rank: 9

积分
989
发表于 2018-10-25 12:49:32 | 显示全部楼层 |阅读模式

使用AI模型来识别图像是桌子、猫、狗,还是其他

本章将演示一个应用AI模型进行图像识别的例子。通过该实例能够让读者真真切切的感受到AI的强大,及使用模型的操作过程。

案例描述
通过代码载入现有模型,对任意图片进行分类识别,观察识别结果。
本案使用的是在ImgNet数据集上训练好的PNASNet模型。PNASNet模型是目前最优秀的图片识别模型之一。该模型在ImgNet数据集上训练后,可以识别1000种类别的图片。要完成该案例,需要先下载TensorFlow中的models模块及对应的与训练模型。下面就来详细介绍。
代码环境及模型准备
为了使读者能够快速完成该实例,直观上感受到模型的识别能力,可以直接使用本书配套的资源。并将其放到代码的同级目录下即可。
如果想体验下从零开始手动搭建,也可以按照下面的方法准备代码环境及预编译模型。
1. 下载TensorFlow models模块
TensorFlow models模块中包含了使用TensorFlow框架完成的各种不同模型,可以直接拿来使用。在TensorFlow models模块中进行二次开发,可以使AI项目开发变得简单快捷。来到以下网址:
https://github.com/tensorflow/models/
可以通过git 将代码clone下来,也可以手动下载(具体操作见《深度学习之TensorFlow:入门、原理与进阶实战》一书的8.5.2节)。
2. 部署TensorFlow slim模块
解压之后,将其中\models-master\research路径下的slim文件夹(如图1),复制到本地代码的同级路径下。

1.png
图1 slim代码库路径

slim库又叫做TF-slim,是TensorFlow 1.0之后推出的一个新的轻量级高级API接口。将很多常见TensorFlow函数做了二次封装,使代码变得更加简洁。
在TF-slim模块里面同时提供了大量用TF-slim写好的网络模型结构代码,以及用该代码训练出的模型文件。本例中就是使用TF-slim模块中训练好的PNASNet模型文件。

3. 下载PNASNet模型
访问如下网站,可以下载训练好的PNASNet模型:
https://github.com/tensorflow/models/tree/master/research/slim
该链接打开后,可以找到“pnasnet-5_large_2017_12_13.tar.gz”的下载地址,如图2。

2.jpg
图2 PNASNet模型下载页面

下载完后,将其解压,会得到如下图3中的文件结构。
3.png
图3 PNASNet模型文件

将整个pnasnet-5_large_2017_12_13文件夹放到本地代码的同级目录下。在使用时,只需要指定好模型的路径:“pnasnet-5_large_2017_12_13”,系统便会自动加载模型里面的文件及内容。
注意:
在图3-2中,可以看到,出来本实例所用的PNASNet模型外,还有好多其他的模型。其中倒数第二行的mobilenet_v2_1.0_224.tgz模型也是比较常用的,该模型体积小、运算快,常用于在移动设备。

4. 准备ImgNet数据集标签
由于本例中使用的PNASNet预训练模型是在ImgNet数据集上训练好的模型,在使用该模型分类是,还需要有与其对应的标签文件。slim中已经将获得标签文件的操作直接封装到了代码里,直接调用即可。由于标签文件是英文分类,读起来不太直观。这里提供了一个翻译好的中文标签分类文件“中文标签.csv”。也在书籍同步的配套资源中。
前面4项都准备好后,整体的目录结构如图4所示。

4.png
图4 实例1文件结构

在图4中,会看到还有三个图片文件“72.jpg”、“hy.jpg”、“ps.jpg”,这三个文件是用于测试使用的图片,读者可以替换为自己所要识别的文件。

代码实现:初始化环境变量,并载入ImgNet标签
首先将本地的slim作为引用库载入到系统的环境变量里。接着将ImgNet标签载入并显示出来。
1 import sys                                                 #初始化环境变量
2 nets_path = r'slim'
3 if nets_path not in sys.path:
4     sys.path.insert(0,nets_path)
5 else:
6     print('already add slim')
7
8 import tensorflow as tf                                   #引入头文件
9 from PIL import Image
10 from matplotlib import pyplot as plt
11 from nets.nasnet import pnasnet
12 import numpy as np
13 from datasets import imagenet
14 slim = tf.contrib.slim
15
16 tf.reset_default_graph()                       
17
18 image_size = pnasnet.build_pnasnet_large.default_image_size       #获得图片输入尺寸
19 labels = imagenet.create_readable_names_for_imagenet_labels()     #获得数据集标签
20 print(len(labels),labels)                                             #显示输出标签
21
22 def getone(onestr):
23    return onestr.replace(',',' ')
24
25 with open('中文标签.csv','r+') as f:                             #打开文件               
26    labels =list( map(getone,list(f))  )
27    print(len(labels),type(labels),labels[:5])
使用AI模型来识别图像

代码中提供了英文与中文的两种标签。在实际应用中使用了中文的标签。程序运行后输出结果如下:

1001 {0: 'background', 1: 'tench, Tinca tinca', 2: 'goldfish, Carassius auratus', 3: 'great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias', 4: 'tiger shark, Galeocerdo cuvieri', 5: 'hammerhead, hammerhead shark',……,994: 'gyromitra', 995: 'stinkhorn, carrion fungus', 996: 'earthstar', 997: 'hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa', 998: 'bolete', 999: 'ear, spike, capitulum', 1000: 'toilet tissue, toilet paper, bathroom tissue'}
1001 <class 'list'> ['背景known   \n', '丁鲷     \n', '金鱼     \n', '大白鲨     \n', '虎鲨     \n']

一共输出了两行,第一行为英文标签,第二行为中文标签。

代码实现:定义网络结构
通过代码,定义了占位符input_imgs,用于输入待识别的图片。接着定义网络节点end_points,对接预训练模型的输出节点。end_points是一个字典,里面Predictions对应的值就是最终的输出结果。该值中放置着1000个元素的数组,代表预测图片在这1000个分类中的概率。通过tf.argmax函数对最终结果进行转化,得到数组中最大的那个数的索引,便是该图片的分类。
28 sample_images = ['hy.jpg', 'ps.jpg','72.jpg']                   #定义待测试图片路径
29
30 input_imgs = tf.placeholder(tf.float32, [None, image_size,image_size,3]) #定义占位符
31
32 x1 = 2 *( input_imgs / 255.0)-1.0                                 #归一化图片
33
34 arg_scope = pnasnet.pnasnet_large_arg_scope()                  #获得模型命名空间
35 with slim.arg_scope(arg_scope):
36    logits, end_points = pnasnet.build_pnasnet_large(x1,num_classes = 1001, is_training=False)   
37    prob = end_points['Predictions']
38    y = tf.argmax(prob,axis = 1)                                  #获得结果的输出节点
使用AI模型来识别图像(续)

在34行代码中的arg_scope是命名空间的意思。在TensorFlow中相同名称的不同张量是通过命名空间来划分的。关于命名空间的更多知识可以参考《深度学习之TensorFlow:入门、原理与进阶实战》一书的4.3节。
代码中第28行指定了待识别图片的名称。如果想识别自己的图片,直接修改该行代码中的图片名称即可。

代码实现:载入模型进行识别
指定好要加载的预训练模型,建立会话进行图片识别。
39 checkpoint_file = r'pnasnet-5_large_2017_12_13\model.ckpt'       #定义模型路径
40 saver = tf.train.Saver()                                                #定义saver,用于加载模型
41 with tf.Session() as sess:                                              #建立会话
42    saver.restore(sess, checkpoint_file)                            #载入模型
43
44    def preimg(img):                                    #定义图片预处理函数
45        ch = 3
46        if img.mode=='RGBA':                            #兼容RGBA图片
47            ch = 4
48
49        imgnp = np.asarray(img.resize((image_size,image_size)),
50                          dtype=np.float32).reshape(image_size,image_size,ch)
51        return imgnp[:,:,:3]
52
53    #获得原始图片与预处理图片
54    batchImg = [ preimg( Image.open(imgfilename) ) for imgfilename in sample_images ]
55    orgImg = [  Image.open(imgfilename)  for imgfilename in sample_images ]
56
57    yv,img_norm = sess.run([y,x1], feed_dict={input_imgs: batchImg})    #输入到模型
58
59    print(yv,np.shape(yv))                                              #显示输出结果         
60    def showresult(yy,img_norm,img_org):                            #定义显示图片函数
61        plt.figure()  
62        p1 = plt.subplot(121)
63        p2 = plt.subplot(122)
64        p1.imshow(img_org)                                        #显示图片
65        p1.axis('off')
66        p1.set_title("organization image")
67
68        p2.imshow(img_norm)                                        #显示图片
69        p2.axis('off')
70        p2.set_title("input image")  
71
72        plt.show()
73        print(yy,labels[yy])
74
75    for yy,img1,img2 in zip(yv,batchImg,orgImg):                    #显示每条结果及图片
76        showresult(yy,img1,img2)
使用AI模型来识别图像(续)

在TensorFlow中,模型运行时会有个图的概念。在本例中,原始的网络结构会在静态图中定义好,接着通过建立一个会话(代码41行)让当前代码与静态图连接起来。调用sess中的run函数将数据输入到静态图中,并返回结果,从而实现图片的识别。
在模型识别之前,所有的图片都要统一成固定大小的尺寸(代码49行),并进行归一化(代码32行)。这个过程叫做图片预处理。经过预处理后的图片放到模型中,才能够得到准确的结果。
代码运行后,输出结果如下:

5.jpg

结果一共显示了3幅图,3段文字。每幅图片下一行的文字,为模型识别出来的结果。在每幅图中,左侧为原始图片,右侧为预处理后的图片。

连载1代码.rar (1.23 MB, 下载次数: 185)
回复

使用道具 举报

0

主题

1

帖子

4

积分

新手上路

Rank: 1

积分
4
发表于 2019-6-27 10:55:31 | 显示全部楼层
x1 = 2 *( input_imgs / 255.0)-1.0 ,一般的归一化不是直接除以255.0吗?这里的乘法和减法是出于什么考虑的?
回复

使用道具 举报

127

主题

308

帖子

989

积分

管理员

Rank: 9Rank: 9Rank: 9

积分
989
 楼主| 发表于 2019-7-25 07:14:07 来自手机 | 显示全部楼层
你说的那种做法也可以。这种做法是直接将值域归一化到-1和1之间。比直接除255得到的0到1,值域更广
来自: 微社区
回复

使用道具 举报

127

主题

308

帖子

989

积分

管理员

Rank: 9Rank: 9Rank: 9

积分
989
 楼主| 发表于 2019-7-25 07:15:07 来自手机 | 显示全部楼层
类似这种归一化的方法很多。只要处理后的值域合规即可
来自: 微社区
回复

使用道具 举报

127

主题

308

帖子

989

积分

管理员

Rank: 9Rank: 9Rank: 9

积分
989
 楼主| 发表于 2019-8-1 05:36:01 | 显示全部楼层
微信群链接已经失效。请加微信:elexment  申请进入微信群
回复

使用道具 举报

1

主题

2

帖子

7

积分

新手上路

Rank: 1

积分
7
发表于 2019-10-3 10:00:57 | 显示全部楼层
运行后报错

Traceback (most recent call last):

  File "<ipython-input-44-dad7a676ac65>", line 1, in <module>
    runfile('E:/work/tensorflow/pic/res/TEST.py', wdir='E:/work/tensorflow/pic/res')

  File "d:\Users\MSI-PC\Anaconda3\lib\site-packages\spyder\utils\site\sitecustomize.py", line 710, in runfile
    execfile(filename, namespace)

  File "d:\Users\MSI-PC\Anaconda3\lib\site-packages\spyder\utils\site\sitecustomize.py", line 101, in execfile
    exec(compile(f.read(), filename, 'exec'), namespace)

  File "E:/work/tensorflow/pic/res/TEST.py", line 43, in <module>
    logits, end_points = pnasnet.build_pnasnet_large(x1,num_classes = 1001, is_training=False)

  File "slim\nets\nasnet\pnasnet.py", line 172, in build_pnasnet_large
    nasnet._update_hparams(hparams, is_training)

  File "slim\nets\nasnet\nasnet.py", line 116, in _update_hparams
    hparams.set_hparam('drop_path_keep_prob', 1.0)

AttributeError: 'HParams' object has no attribute 'set_hparam'
回复

使用道具 举报

0

主题

1

帖子

8

积分

新手上路

Rank: 1

积分
8
发表于 2020-1-10 08:59:26 | 显示全部楼层
为什么报错  OSError: [WinError 10013] 以一种访问权限不允许的方式做了一个访问套接字的尝试?  是需要修改爬虫的函数吗?
回复

使用道具 举报

127

主题

308

帖子

989

积分

管理员

Rank: 9Rank: 9Rank: 9

积分
989
 楼主| 发表于 2020-2-9 18:23:59 | 显示全部楼层
你访问不了国外网站。将19行注释掉即可
回复

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

Archiver|手机版|小黑屋|大蛇智能 ( 京ICP备 18026897 号 )

GMT+8, 2020-2-27 10:33 , Processed in 0.047809 second(s), 27 queries .

Powered by Discuz! X3.4

© 2001-2017 Comsenz Inc.

快速回复 返回顶部 返回列表