Baste 发表于 2022-3-26 11:45

ops::TFRecordReader类使用

TFRecord 文件中的每条记录通常是序列化后的 protobuf 消息(tf.train.Example),因此可以在不同语言之间交换数据,主要用于存储机器学习的特征数据。TFRecordReader 每次只读取一条记录(即 per-record 读取),读出的原始字节串还需要用 ParseSingleExample 等算子解析成各个特征。
Python 侧构造TFRecord的代码如下,
test_create_tf_record.py
"""Write 100 serialized tf.train.Example records to /tmp/train.tfrecord.

Each record i (0..99) holds two features:
  * 'label'   -- Int64List containing the record index i.
  * 'img_raw' -- BytesList containing a UTF-8 JSON dump of a 30x7 array of
                 random ints in [0, 255].
"""
import json

import numpy as np
import tensorflow as tf

tfrecord_filename = '/tmp/train.tfrecord'

# Create the .tfrecord file for writing; the `with` block guarantees the
# writer is flushed and closed even if an exception is raised mid-loop.
with tf.compat.v1.python_io.TFRecordWriter(tfrecord_filename) as writer:
    for i in range(100):
        # 30x7 array of random ints in [0, 255]. randint's upper bound is
        # exclusive, so 256 reproduces the inclusive range of the deprecated
        # np.random.random_integers(0, 255, ...).
        img_raw = np.random.randint(0, 256, size=(30, 7))
        img_raw = bytes(json.dumps(img_raw.tolist()), "utf-8")
        example = tf.compat.v1.train.Example(features=tf.train.Features(
            feature={
                # Int64List stores integer data; one value per record.
                'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[i])),
                # BytesList stores raw binary data; one value per record.
                'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])),
            }))
        # Serialize the Example and append it as a single record.
        writer.write(example.SerializeToString())
程序目录结构如下,


image.png

程序代码如下,
CMakeLists.txt
# CMAKE_CXX_STANDARD 17 requires CMake >= 3.8 (pkg_search_module's
# IMPORTED_TARGET option requires >= 3.6), so 3.3 was too low.
cmake_minimum_required(VERSION 3.8)
project(test_parse_ops)

set(ENV{PKG_CONFIG_PATH} "$ENV{PKG_CONFIG_PATH}:/usr/local/lib/pkgconfig/")
set(CMAKE_CXX_STANDARD 17)
add_definitions(-g)

include(${CMAKE_BINARY_DIR}/conanbuildinfo.cmake)
conan_basic_setup()

find_package(TensorflowCC REQUIRED)
find_package(PkgConfig REQUIRED)
pkg_search_module(PKG_PARQUET REQUIRED IMPORTED_TARGET parquet)
pkg_search_module(PKG_ARROW REQUIRED IMPORTED_TARGET arrow)
pkg_search_module(PKG_ARROW_COMPUTE REQUIRED IMPORTED_TARGET arrow-compute)
pkg_search_module(PKG_ARROW_CSV REQUIRED IMPORTED_TARGET arrow-csv)
pkg_search_module(PKG_ARROW_DATASET REQUIRED IMPORTED_TARGET arrow-dataset)
pkg_search_module(PKG_ARROW_FS REQUIRED IMPORTED_TARGET arrow-filesystem)
pkg_search_module(PKG_ARROW_JSON REQUIRED IMPORTED_TARGET arrow-json)

set(ARROW_INCLUDE_DIRS
    ${PKG_PARQUET_INCLUDE_DIRS}
    ${PKG_ARROW_INCLUDE_DIRS}
    ${PKG_ARROW_COMPUTE_INCLUDE_DIRS}
    ${PKG_ARROW_CSV_INCLUDE_DIRS}
    ${PKG_ARROW_DATASET_INCLUDE_DIRS}
    ${PKG_ARROW_FS_INCLUDE_DIRS}
    ${PKG_ARROW_JSON_INCLUDE_DIRS})
set(INCLUDE_DIRS ${CMAKE_CURRENT_SOURCE_DIR}/../../include ${ARROW_INCLUDE_DIRS})
set(ARROW_LIBS
    PkgConfig::PKG_PARQUET
    PkgConfig::PKG_ARROW
    PkgConfig::PKG_ARROW_COMPUTE
    PkgConfig::PKG_ARROW_CSV
    PkgConfig::PKG_ARROW_DATASET
    PkgConfig::PKG_ARROW_FS
    PkgConfig::PKG_ARROW_JSON)
include_directories(${INCLUDE_DIRS})

# Every *.cpp in this directory becomes its own test executable.
file(GLOB test_file_list ${CMAKE_CURRENT_SOURCE_DIR}/*.cpp)
# Shared helper sources compiled once into a library the tests link against.
file(GLOB APP_SOURCES
    ${CMAKE_CURRENT_SOURCE_DIR}/../../include/tf_/impl/tensor_testutil.cc
    ${CMAKE_CURRENT_SOURCE_DIR}/../../include/tf_/impl/queue_runner.cc
    ${CMAKE_CURRENT_SOURCE_DIR}/../../include/tf_/impl/coordinator.cc
    ${CMAKE_CURRENT_SOURCE_DIR}/../../include/tf_/impl/status.cc
    ${CMAKE_CURRENT_SOURCE_DIR}/../../include/death_handler/impl/*.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/../../include/df/impl/*.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/../../include/arr_/impl/*.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/../../include/img_util/impl/*.cpp)

add_library(${PROJECT_NAME}_lib SHARED ${APP_SOURCES})
target_link_libraries(${PROJECT_NAME}_lib PUBLIC ${CONAN_LIBS} TensorflowCC::TensorflowCC ${ARROW_LIBS})

foreach(test_file ${test_file_list})
    file(RELATIVE_PATH filename ${CMAKE_CURRENT_SOURCE_DIR} ${test_file})
    string(REPLACE ".cpp" "" file ${filename})
    # Target name and source file must be separate arguments; writing
    # ${file}${test_file} concatenated them into one bogus target name.
    add_executable(${file} ${test_file})
    target_link_libraries(${file} PUBLIC ${PROJECT_NAME}_lib)
endforeach()
tf_record_reader_test.cpp
#include <string>#include <vector>#include <array>#include <fstream>#include <glog/logging.h>#include "tensorflow/core/platform/test.h"#include "death_handler/death_handler.h"#include "tf_/tensor_testutil.h"#include "tensorflow/cc/framework/scope.h"#include "tensorflow/cc/client/client_session.h"#include "tensorflow/cc/ops/standard_ops.h"#include "tensorflow/cc/training/coordinator.h"#include "tensorflow/core/framework/graph.pb.h"#include "tensorflow/core/framework/tensor.h"#include "tensorflow/core/framework/tensor_shape.h"#include "tensorflow/core/framework/types.pb.h"#include "tensorflow/core/lib/core/notification.h"#include "tensorflow/core/lib/core/status_test_util.h"#include "tensorflow/core/platform/env.h"#include "tensorflow/core/platform/test.h"#include "tensorflow/core/protobuf/error_codes.pb.h"#include "tensorflow/core/protobuf/queue_runner.pb.h"#include "tensorflow/core/public/session.h"#include "tf_/queue_runner.h"using namespace tensorflow;int main(int argc, char** argv) {    FLAGS_log_dir = "./";    FLAGS_alsologtostderr = true;    // 日志级别 INFO, WARNING, ERROR, FATAL 的值分别为0、1、2、3    FLAGS_minloglevel = 0;    Debug::DeathHandler dh;    google::InitGoogleLogging("./logs.log");    ::testing::InitGoogleTest(&argc, argv);    int ret = RUN_ALL_TESTS();    return ret;}TEST(TfArrayOpsTests, FixLenRecordReader) {    // 定长读取文本文件    // https://www.tensorflow.org/versions/r2.6/api_docs/cc/class/tensorflow/ops/fixed-length-record-reader#classtensorflow_1_1ops_1_1_fixed_length_record_reader_1aa6ad72f08d89016af3043f72912d11eb    Scope root = Scope::NewRootScope();      auto attrs = ops::FIFOQueue::Capacity(200);    auto queue_ = ops::FIFOQueue(root.WithOpName("queue"), {DT_STRING}, attrs);    auto tensor_ = ops::Const(root, {"/cppwork/_tf/test_parse_ops/data/train.tfrecord"});    auto enque_ = ops::QueueEnqueueMany(root.WithOpName("enque"), queue_, {tensor_});    auto close_ = ops::QueueClose(root.WithOpName("close"), queue_);    auto reader = 
ops::TFRecordReader(root);    auto read_res = ops::ReaderRead(root.WithOpName("rec_read"), reader, queue_);    Tensor dense_def0(DT_STRING, {1});    Tensor dense_def1(DT_INT64, {1});    // 1. 这个函数很坑,它读取TFRecordReader对象的输出值,read_res.value, 接收两个和输出值类型相同的默认输出值    // dense_def0 , dense_def1,当然你的数据如果有三个Feature,这里就三个默认值,注意默认值需要与输出值类型相同    // 2. "img_raw", "label" 是Python侧命名的标签    // 3. {1}, {1} 是代表单个特征的大小    // 4. 注意这里支持的类型只有 DT_INT64, DT_STRING和 DT_FLOAT64,其中DT_STRING在Python侧表现为bytearray    auto parse_op = ops::ParseSingleExample(root.WithOpName("parse_op"), {read_res.value}, {dense_def0, dense_def1}, 0, {}, {"img_raw", "label"}, {}, {{1}, {1}});      SessionOptions options;    std::unique_ptr<Session> session(NewSession(options));    GraphDef graph_def;    TF_EXPECT_OK(root.ToGraphDef(&graph_def));    session->Create(graph_def);    QueueRunnerDef queue_runner_def =      test::BuildQueueRunnerDef("queue", {"enque"}, "close", "", {tensorflow::error::CANCELLED});    std::unique_ptr<QueueRunner> qr;    TF_EXPECT_OK(QueueRunner::New(queue_runner_def, &qr));    TF_CHECK_OK(qr->Start(session.get()));    TF_EXPECT_OK(session->Run({}, {}, {"enque"}, nullptr));    TF_EXPECT_OK(session->Run({}, {}, {"close"}, nullptr));    std::vector<Tensor> outputs;    // 这里 Run 一次会获取一个特征    for(int i=0; i< 100; ++i) {      std::vector<Tensor> outputs_res;      session->Run({}, {parse_op.dense_values.name(), parse_op.dense_values.name()}, {}, &outputs_res);      std::cout << outputs_res.DebugString() << "\n";      std::cout << outputs_res.DebugString() << "\n";      auto res = test::GetTensorValue<int64>(outputs_res);      ASSERT_EQ(i, res);    }    TF_EXPECT_OK(qr->Join());}
程序输出如下,


image.png
页: [1]
查看完整版本: ops::TFRecordReader类使用