ops::TFRecordReader类使用
TFRecord 文件中的每条记录通常是一个序列化后的 protobuf 消息(例如 tf.train.Example),因此可以在不同的语言之间交换数据,主要用于存放机器学习的特征数据。需要注意的是,TFRecordReader 在读取时一次只能读取一条记录,也就是按记录(per record)读取。Python 侧构造 TFRecord 文件的代码如下:
test_create_tf_record.py
"""Write 100 tf.train.Example records to a TFRecord file.

Each record carries two features:
  * 'label'   - int64 list holding the record index (0..99)
  * 'img_raw' - bytes list holding a JSON-encoded 30x7 array of random
                integers in [0, 255]
"""
import json

import numpy as np
import tensorflow as tf

tfrecord_filename = '/tmp/train.tfrecord'

# Create the .tfrecord file; the context manager closes the writer once all
# records are written (and also if an exception is raised part-way through).
with tf.compat.v1.python_io.TFRecordWriter(tfrecord_filename) as writer:
    for i in range(100):
        # 30x7 array of random integers in [0, 255]. randint's upper bound is
        # exclusive, hence 256 (replaces the deprecated random_integers).
        img_raw = np.random.randint(0, 256, size=(30, 7))
        # Serialize the array as UTF-8 encoded JSON so it fits in a BytesList.
        img_raw = bytes(json.dumps(img_raw.tolist()), "utf-8")
        example = tf.compat.v1.train.Example(features=tf.train.Features(
            feature={
                # Int64List stores integer data; `value` must be a list.
                'label': tf.train.Feature(
                    int64_list=tf.train.Int64List(value=[i])),
                # BytesList stores raw binary data; `value` must be a list.
                'img_raw': tf.train.Feature(
                    bytes_list=tf.train.BytesList(value=[img_raw])),
            }))
        # Serialize the Example and append it to the file as one record.
        writer.write(example.SerializeToString())
程序目录结构如下,
image.png
程序代码如下,
CMakeLists.txt
# Builds one shared helper library from the common sources under
# ../../include, then one executable per *.cpp file in this directory, each
# linked against that library.

# pkg_search_module(... IMPORTED_TARGET ...) below requires CMake >= 3.6.
cmake_minimum_required(VERSION 3.6)
project(test_parse_ops)

# Let pkg-config also find the .pc files installed under /usr/local.
set(ENV{PKG_CONFIG_PATH} "$ENV{PKG_CONFIG_PATH}:/usr/local/lib/pkgconfig/")

set(CMAKE_CXX_STANDARD 17)
# -g is a compiler option, not a preprocessor definition.
add_compile_options(-g)

include(${CMAKE_BINARY_DIR}/conanbuildinfo.cmake)
conan_basic_setup()

find_package(TensorflowCC REQUIRED)
find_package(PkgConfig REQUIRED)

pkg_search_module(PKG_PARQUET REQUIRED IMPORTED_TARGET parquet)
pkg_search_module(PKG_ARROW REQUIRED IMPORTED_TARGET arrow)
pkg_search_module(PKG_ARROW_COMPUTE REQUIRED IMPORTED_TARGET arrow-compute)
pkg_search_module(PKG_ARROW_CSV REQUIRED IMPORTED_TARGET arrow-csv)
pkg_search_module(PKG_ARROW_DATASET REQUIRED IMPORTED_TARGET arrow-dataset)
pkg_search_module(PKG_ARROW_FS REQUIRED IMPORTED_TARGET arrow-filesystem)
pkg_search_module(PKG_ARROW_JSON REQUIRED IMPORTED_TARGET arrow-json)

set(ARROW_INCLUDE_DIRS
    ${PKG_PARQUET_INCLUDE_DIRS}
    ${PKG_ARROW_INCLUDE_DIRS}
    ${PKG_ARROW_COMPUTE_INCLUDE_DIRS}
    ${PKG_ARROW_CSV_INCLUDE_DIRS}
    ${PKG_ARROW_DATASET_INCLUDE_DIRS}
    ${PKG_ARROW_FS_INCLUDE_DIRS}
    ${PKG_ARROW_JSON_INCLUDE_DIRS})

set(INCLUDE_DIRS ${CMAKE_CURRENT_SOURCE_DIR}/../../include ${ARROW_INCLUDE_DIRS})

set(ARROW_LIBS
    PkgConfig::PKG_PARQUET
    PkgConfig::PKG_ARROW
    PkgConfig::PKG_ARROW_COMPUTE
    PkgConfig::PKG_ARROW_CSV
    PkgConfig::PKG_ARROW_DATASET
    PkgConfig::PKG_ARROW_FS
    PkgConfig::PKG_ARROW_JSON)

include_directories(${INCLUDE_DIRS})

# Every .cpp file in this directory becomes its own test executable.
file(GLOB test_file_list ${CMAKE_CURRENT_SOURCE_DIR}/*.cpp)

# Shared helper sources compiled once into ${PROJECT_NAME}_lib.
file(GLOB APP_SOURCES
    ${CMAKE_CURRENT_SOURCE_DIR}/../../include/tf_/impl/tensor_testutil.cc
    ${CMAKE_CURRENT_SOURCE_DIR}/../../include/tf_/impl/queue_runner.cc
    ${CMAKE_CURRENT_SOURCE_DIR}/../../include/tf_/impl/coordinator.cc
    ${CMAKE_CURRENT_SOURCE_DIR}/../../include/tf_/impl/status.cc
    ${CMAKE_CURRENT_SOURCE_DIR}/../../include/death_handler/impl/*.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/../../include/df/impl/*.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/../../include/arr_/impl/*.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/../../include/img_util/impl/*.cpp)

add_library(${PROJECT_NAME}_lib SHARED ${APP_SOURCES})
target_link_libraries(${PROJECT_NAME}_lib PUBLIC
    ${CONAN_LIBS}
    TensorflowCC::TensorflowCC
    ${ARROW_LIBS})

foreach(test_file ${test_file_list})
    # Target name = source file name relative to this directory, minus ".cpp".
    file(RELATIVE_PATH filename ${CMAKE_CURRENT_SOURCE_DIR} ${test_file})
    string(REPLACE ".cpp" "" test_target ${filename})
    # The target name and its source file must be separate arguments.
    add_executable(${test_target} ${test_file})
    target_link_libraries(${test_target} PUBLIC ${PROJECT_NAME}_lib)
endforeach()
tf_record_reader_test.cpp
#include <string>#include <vector>#include <array>#include <fstream>#include <glog/logging.h>#include "tensorflow/core/platform/test.h"#include "death_handler/death_handler.h"#include "tf_/tensor_testutil.h"#include "tensorflow/cc/framework/scope.h"#include "tensorflow/cc/client/client_session.h"#include "tensorflow/cc/ops/standard_ops.h"#include "tensorflow/cc/training/coordinator.h"#include "tensorflow/core/framework/graph.pb.h"#include "tensorflow/core/framework/tensor.h"#include "tensorflow/core/framework/tensor_shape.h"#include "tensorflow/core/framework/types.pb.h"#include "tensorflow/core/lib/core/notification.h"#include "tensorflow/core/lib/core/status_test_util.h"#include "tensorflow/core/platform/env.h"#include "tensorflow/core/platform/test.h"#include "tensorflow/core/protobuf/error_codes.pb.h"#include "tensorflow/core/protobuf/queue_runner.pb.h"#include "tensorflow/core/public/session.h"#include "tf_/queue_runner.h"using namespace tensorflow;int main(int argc, char** argv) { FLAGS_log_dir = "./"; FLAGS_alsologtostderr = true; // 日志级别 INFO, WARNING, ERROR, FATAL 的值分别为0、1、2、3 FLAGS_minloglevel = 0; Debug::DeathHandler dh; google::InitGoogleLogging("./logs.log"); ::testing::InitGoogleTest(&argc, argv); int ret = RUN_ALL_TESTS(); return ret;}TEST(TfArrayOpsTests, FixLenRecordReader) { // 定长读取文本文件 // https://www.tensorflow.org/versions/r2.6/api_docs/cc/class/tensorflow/ops/fixed-length-record-reader#classtensorflow_1_1ops_1_1_fixed_length_record_reader_1aa6ad72f08d89016af3043f72912d11eb Scope root = Scope::NewRootScope(); auto attrs = ops::FIFOQueue::Capacity(200); auto queue_ = ops::FIFOQueue(root.WithOpName("queue"), {DT_STRING}, attrs); auto tensor_ = ops::Const(root, {"/cppwork/_tf/test_parse_ops/data/train.tfrecord"}); auto enque_ = ops::QueueEnqueueMany(root.WithOpName("enque"), queue_, {tensor_}); auto close_ = ops::QueueClose(root.WithOpName("close"), queue_); auto reader = ops::TFRecordReader(root); auto read_res = 
ops::ReaderRead(root.WithOpName("rec_read"), reader, queue_); Tensor dense_def0(DT_STRING, {1}); Tensor dense_def1(DT_INT64, {1}); // 1. 这个函数很坑,它读取TFRecordReader对象的输出值,read_res.value, 接收两个和输出值类型相同的默认输出值 // dense_def0 , dense_def1,当然你的数据如果有三个Feature,这里就三个默认值,注意默认值需要与输出值类型相同 // 2. "img_raw", "label" 是Python侧命名的标签 // 3. {1}, {1} 是代表单个特征的大小 // 4. 注意这里支持的类型只有 DT_INT64, DT_STRING和 DT_FLOAT64,其中DT_STRING在Python侧表现为bytearray auto parse_op = ops::ParseSingleExample(root.WithOpName("parse_op"), {read_res.value}, {dense_def0, dense_def1}, 0, {}, {"img_raw", "label"}, {}, {{1}, {1}}); SessionOptions options; std::unique_ptr<Session> session(NewSession(options)); GraphDef graph_def; TF_EXPECT_OK(root.ToGraphDef(&graph_def)); session->Create(graph_def); QueueRunnerDef queue_runner_def = test::BuildQueueRunnerDef("queue", {"enque"}, "close", "", {tensorflow::error::CANCELLED}); std::unique_ptr<QueueRunner> qr; TF_EXPECT_OK(QueueRunner::New(queue_runner_def, &qr)); TF_CHECK_OK(qr->Start(session.get())); TF_EXPECT_OK(session->Run({}, {}, {"enque"}, nullptr)); TF_EXPECT_OK(session->Run({}, {}, {"close"}, nullptr)); std::vector<Tensor> outputs; // 这里 Run 一次会获取一个特征 for(int i=0; i< 100; ++i) { std::vector<Tensor> outputs_res; session->Run({}, {parse_op.dense_values.name(), parse_op.dense_values.name()}, {}, &outputs_res); std::cout << outputs_res.DebugString() << "\n"; std::cout << outputs_res.DebugString() << "\n"; auto res = test::GetTensorValue<int64>(outputs_res); ASSERT_EQ(i, res); } TF_EXPECT_OK(qr->Join());}
程序输出如下,
image.png
页:
[1]