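First, download the TensorFlow models archive, which provides the research folder used in the steps below, and extract it to a suitable directory. Make sure you download the .tar.gz version rather than another format.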
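The extraction step can also be scripted. Below is a minimal sketch using Python's standard tarfile module. It is not part of the original steps, and the archive and folder names in the usage note are hypothetical placeholders. It rejects archive entries that would land outside the destination folder, and it uses the "data" extraction filter when the running Python provides one.

```python
import os
import tarfile


def extract_archive(archive_path, dest_dir):
    """Extract a .tar.gz archive into dest_dir and return its top-level entries.

    archive_path and dest_dir are placeholders for your own paths.
    """
    os.makedirs(dest_dir, exist_ok=True)
    dest_root = os.path.realpath(dest_dir)
    with tarfile.open(archive_path, "r:gz") as tar:
        for member in tar.getmembers():
            # Refuse entries that would be written outside dest_dir (e.g. "../x").
            target = os.path.realpath(os.path.join(dest_root, member.name))
            if os.path.commonpath([dest_root, target]) != dest_root:
                raise ValueError("unsafe path in archive: " + member.name)
        if hasattr(tarfile, "data_filter"):
            # Newer Pythons support extraction filters; "data" is the safe preset.
            tar.extractall(dest_root, filter="data")
        else:
            tar.extractall(dest_root)
    return sorted(os.listdir(dest_root))
```

For example, extract_archive("models.tar.gz", "models") (hypothetical names) would unpack the archive and return the names of its top-level entries.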
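Download the Protobuf compiler release, and make sure you pick the win64 build (protoc-xxx-win64.zip). After extracting it, copy protoc.exe from the bin directory into C:\Windows\System32, which is on the PATH by default.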
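One way to confirm that protoc.exe is now reachable is to ask Python where it resolves. This is a hedged helper, not part of the original guide. It uses shutil.which to search the PATH, and it assumes `protoc --version` prints a line such as `libprotoc 3.4.0`.

```python
import re
import shutil
import subprocess


def parse_protoc_version(text):
    """Turn output such as 'libprotoc 3.4.0' into a tuple of ints, or None."""
    match = re.search(r"libprotoc\s+(\d+(?:\.\d+)*)", text)
    if match is None:
        return None
    return tuple(int(part) for part in match.group(1).split("."))


def protoc_version():
    """Return protoc's version tuple, or None if protoc is not on the PATH."""
    exe = shutil.which("protoc")
    if exe is None:
        return None
    result = subprocess.run([exe, "--version"], capture_output=True, text=True, check=True)
    return parse_protoc_version(result.stdout)
```

In a freshly opened PowerShell window, if protoc_version() returns a tuple rather than None, the copy into System32 probably worked.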
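Open PowerShell, change to the research folder, and run the following command. PowerShell is required here because the command uses PowerShell cmdlets and fails in cmd.

Get-ChildItem object_detection/protos/*.proto | Resolve-Path -Relative | % { protoc $_ --python_out=. }

On success, the protos folder contains a generated Python module (named *_pb2.py) for each .proto file.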
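If PowerShell is inconvenient, you can reproduce the same loop from Python. This is an alternative I have swapped in, not the author's method. The minimal sketch below assumes protoc is on the PATH and that you pass it the path to your research folder. It builds one protoc invocation per .proto file, as the PowerShell pipeline does, and it also shows the `_pb2.py` naming used for the generated modules.

```python
import glob
import os
import subprocess


def protoc_commands(research_dir):
    """Build one protoc command per .proto file, mirroring the PowerShell loop.

    Proto paths are made relative to research_dir, like Resolve-Path -Relative.
    """
    pattern = os.path.join(research_dir, "object_detection", "protos", "*.proto")
    commands = []
    for proto in sorted(glob.glob(pattern)):
        rel = os.path.relpath(proto, research_dir)
        commands.append(["protoc", rel, "--python_out=."])
    return commands


def compile_protos(research_dir):
    """Run every protoc command with research_dir as the working directory."""
    for cmd in protoc_commands(research_dir):
        subprocess.run(cmd, cwd=research_dir, check=True)


def expected_output(proto_rel_path):
    """protoc --python_out names the generated module <name>_pb2.py."""
    stem, _ = os.path.splitext(proto_rel_path)
    return stem + "_pb2.py"
```

For example, compile_protos(r"C:\path\to\research") (hypothetical path) runs each command from inside research. With that working directory, --python_out=. writes each module next to its .proto file.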
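In the Anaconda\Lib\site-packages folder, create a path file named tensorflow_model.pth. Its content is the directories of the required modules, one path per line, and the file must be saved with the .pth extension. At startup, Python adds each listed directory to sys.path.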
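You can also write the .pth file programmatically. The sketch below is an illustration, not the original procedure. write_pth and default_site_packages are hypothetical helper names, and default_site_packages assumes you run it with the Anaconda interpreter whose site-packages folder you want to modify.

```python
import os
import sysconfig


def write_pth(site_packages, name, directories):
    """Write a .pth file listing one absolute directory per line.

    At startup, Python's site module appends each listed directory that exists to sys.path.
    """
    path = os.path.join(site_packages, name)
    with open(path, "w", encoding="utf-8") as fh:
        for d in directories:
            fh.write(os.path.abspath(d) + "\n")
    return path


def default_site_packages():
    """site-packages of the running interpreter (e.g. Anaconda\\Lib\\site-packages)."""
    return sysconfig.get_paths()["purelib"]
```

For example, call write_pth(default_site_packages(), "tensorflow_model.pth", [r"C:\path\to\models\research"]) with your own path in place of the hypothetical one. After you restart Python, the directory appears in sys.path, provided it exists.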
From the research folder, install the project with the following commands:

python setup.py build
python setup.py install
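To check that the installation succeeded, you can ask Python whether it can locate the package and the two helper modules the demo script relies on: label_map_util and visualization_utils in object_detection.utils. This minimal sketch is not part of the original guide. It uses importlib.util.find_spec, which locates a module without importing the leaf module itself, although it does import the parent packages.

```python
import importlib.util


def is_importable(module_name):
    """Return True if module_name can be found on the current sys.path."""
    try:
        return importlib.util.find_spec(module_name) is not None
    except ModuleNotFoundError:
        # Raised when a parent package of a dotted name is missing.
        return False


def check_object_detection():
    """Report which of the modules used by the demo script can be located."""
    names = [
        "object_detection",
        "object_detection.utils.label_map_util",
        "object_detection.utils.visualization_utils",
    ]
    return {name: is_importable(name) for name in names}
```

check_object_detection() should map every name to True. A False value usually points to a missed install step or an incorrect .pth entry.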
Create an object_detection_demo.py file (replace the model file link with the actual path):
import os
import tarfile
import urllib.request

import numpy as np
import tensorflow as tf
import matplotlib
# Use the non-interactive Agg backend so figures can be saved without a display;
# matplotlib.use() must be called before pyplot is imported
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from PIL import Image

# Helpers shipped with the Object Detection API (importable after setup.py install)
from object_detection.utils import label_map_util
from object_detection.utils import visualization_utils as vis_util

# Model configuration
MODEL_NAME = 'ssd_mobilenet_v1_coco_2017_11_17'
MODEL_FILE = MODEL_NAME + '.tar.gz'
DOWNLOAD_BASE = 'http://download.tensorflow.org/models/object_detection/'
PATH_TO_CKPT = MODEL_NAME + '/frozen_inference_graph.pb'
PATH_TO_LABELS = os.path.join('data', 'mscoco_label_map.pbtxt')
NUM_CLASSES = 90

# Download the model archive and extract only the frozen graph
if not os.path.exists(PATH_TO_CKPT):
    print('Downloading model...')
    urllib.request.urlretrieve(DOWNLOAD_BASE + MODEL_FILE, MODEL_FILE)
    print('Extracting...')
    with tarfile.open(MODEL_FILE) as tar_file:
        for member in tar_file.getmembers():
            if 'frozen_inference_graph.pb' in os.path.basename(member.name):
                tar_file.extract(member, os.getcwd())
else:
    print('Model already downloaded.')

# Load the frozen graph (TensorFlow 1.x API)
print('Loading model...')
detection_graph = tf.Graph()
with detection_graph.as_default():
    od_graph_def = tf.GraphDef()
    with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
        serialized_graph = fid.read()
        od_graph_def.ParseFromString(serialized_graph)
        tf.import_graph_def(od_graph_def, name='')

# Load the label map
print('Loading label map...')
label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
categories = label_map_util.convert_label_map_to_categories(
    label_map, max_num_classes=NUM_CLASSES, use_display_name=True)
category_index = label_map_util.create_category_index(categories)

# Test image path (adjust to your setup)
TEST_IMAGE_PATH = 'test_images/image1.jpg'

def load_image_into_numpy_array(image):
    """Convert an RGB PIL image to a (height, width, 3) uint8 array."""
    (im_width, im_height) = image.size
    return np.array(image.getdata()).reshape((im_height, im_width, 3)).astype(np.uint8)

# Output figure size in inches
IMAGE_SIZE = (12, 8)

print('Running detection...')
config = tf.ConfigProto()
config.gpu_options.allow_growth = True  # allocate GPU memory on demand
with detection_graph.as_default():
    with tf.Session(graph=detection_graph, config=config) as sess:
        print(TEST_IMAGE_PATH)
        image = Image.open(TEST_IMAGE_PATH).convert('RGB')
        image_np = load_image_into_numpy_array(image)
        # The model expects a batch of images: shape [1, height, width, 3]
        image_np_expanded = np.expand_dims(image_np, axis=0)
        image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
        boxes = detection_graph.get_tensor_by_name('detection_boxes:0')
        scores = detection_graph.get_tensor_by_name('detection_scores:0')
        classes = detection_graph.get_tensor_by_name('detection_classes:0')
        num_detections = detection_graph.get_tensor_by_name('num_detections:0')
        (boxes, scores, classes, num_detections) = sess.run(
            [boxes, scores, classes, num_detections],
            feed_dict={image_tensor: image_np_expanded})
        # Draw boxes and labels onto image_np
        vis_util.visualize_boxes_and_labels_on_image_array(
            image_np,
            np.squeeze(boxes),
            np.squeeze(classes).astype(np.int32),
            np.squeeze(scores),
            category_index,
            use_normalized_coordinates=True,
            line_thickness=8)
        plt.figure(figsize=IMAGE_SIZE, dpi=300)
        plt.imshow(image_np)
        plt.savefig(os.path.splitext(TEST_IMAGE_PATH)[0] + '_labeled.jpg')

Run the script from the object_detection folder, because data/mscoco_label_map.pbtxt and test_images/image1.jpg are resolved relative to the current directory. After it runs, the test_images folder contains an additional labeled image named image1_labeled.jpg, which confirms that the setup works.
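The sess.run call returns raw arrays. In the Object Detection API, detection_boxes holds normalized [ymin, xmin, ymax, xmax] coordinates, and detection_classes holds class IDs as floats, which is why the script applies astype(np.int32). If you want the detections as numbers instead of a drawn image, a minimal NumPy sketch like the one below can filter them by score and convert boxes to pixels. filter_detections is a hypothetical helper, and its 0.5 cut-off is an assumed threshold that the original code does not set.

```python
import numpy as np


def filter_detections(boxes, scores, classes, image_height, image_width, min_score=0.5):
    """Keep detections scoring at least min_score and convert boxes to pixels.

    boxes: (N, 4) array of normalized [ymin, xmin, ymax, xmax] values, such as
    np.squeeze(boxes) from the script. Returns a list of
    (class_id, score, (top, left, bottom, right)) tuples in pixel units.
    """
    boxes = np.asarray(boxes, dtype=np.float64)
    scores = np.asarray(scores, dtype=np.float64)
    classes = np.asarray(classes).astype(np.int32)
    keep = scores >= min_score
    # Scale y values by the height and x values by the width.
    scale = np.array([image_height, image_width, image_height, image_width], dtype=np.float64)
    pixel_boxes = np.round(boxes[keep] * scale).astype(np.int32)
    results = []
    for cls, score, box in zip(classes[keep], scores[keep], pixel_boxes):
        results.append((int(cls), float(score), tuple(int(v) for v in box)))
    return results
```

For example, after applying np.squeeze to the three outputs, filter_detections(boxes, scores, classes, image_np.shape[0], image_np.shape[1]) returns one tuple per kept detection, and category_index[class_id]['name'] gives its label.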
Reposted from: http://fynfk.baihongyu.com/