找回密码
 立即注册
查看: 235|回复: 0

【PyTorch深度学习项目实战100例】—— 基于pytorch使用LSTM实现新闻本文分类任务 | 第9例

[复制链接]
发表于 2022-12-14 13:26 | 显示全部楼层 |阅读模式
前言

大家好,我是阿光。

本专栏整理了《PyTorch深度学习项目实战100例》,内包含了各种不同的深度学习项目,包含项目原理以及源码,每一个项目实例都附带有完整的代码+数据集。

正在更新中~

我的项目环境:
    平台:Windows10语言环境:python3.7编译器:PyCharmPyTorch版本:1.8.1

项目专栏:【PyTorch深度学习项目实战100例】
<hr>一、使用LSTM实现新闻本文分类任务

本文主要用LSTM循环神经网络拟合微调实现一个包含十五个类别的新闻文本分类任务,主要是对新闻内容进行特征抽取,获取语义分析来实现分类任务。

在这里插入图片描述

二、数据集介绍

整个数据集整合划分出15个候选分类类别:法治、国际、国内、健康、教育、经济、军事、科技、农经、三农、人物、社会、生活、书画、文娱的文本数据。

数据总共有4482条新闻纪录,字段分别为:标题、标题链接、新闻内容、关键词、发布时间、标签、新闻采集时间。

在这里插入图片描述

三、项目实现思路

为了体会LSTM的作用,并没有对原始数据进行高纬度的建模,只是使用了新闻内容的这个特征,没有对其他特征进行建模,由于新闻内容是中文文本数据,所以我们需要对其进行向量化,转成数值型数据然后送入到网络模型。

但是对于文本数据来讲,如果只是单纯使用Embedding进行嵌入的话,完全没有考虑到语义那种前后联系,会导致模型训练效果较差。

所以我们本项目中使用了LSTM这种网络进行捕捉语义信息,因为LSTM是一个循环神经网络,它可以将上个学习步的细胞信息传递给下个细胞,这样就会把前面出现的语句信息与当前输入进行结合,来预测之后出现的语句。

首先对于输入数据我们将其进行序列化,将一句话中的所有字转成对应的索引号,如果长度不足,我们需要使用0进行填充,保证输入到网络模型中的向量长度一致,然后需要使用Embedding进行将其进行嵌入,获得每个字的嵌入连续型向量,此处也可以使用one-hot编码,但是这会导致维度爆炸,以及矩阵稀疏问题,之后把生成的嵌入向量导入到LSTM层中,然后,因为这个时间片已经保存了整个语句的语义信息
四、网络结构

项目中使用的模型是LSTM,在模型中我们定义了三个组件,分别是embedding层,lstm层和全连接层。

在这里插入图片描述

    Embedding层:将每个词生成对应的嵌入向量,就是利用一个连续型向量来表示每个词Lstm层:提取语句中的语义信息Linear层:将结果映射成2大小用于二分类,即正反面的概率

注意:在LSTM网络中返回的值为最后一个时间片的输出,而不是将整个output全部输出,因为我们是需要捕捉整个语句的语义信息,并不是获得特定时间片的数据。
五、语句测试

    首先需要对我们待测试的语句进行转为序号编码如果序列长度不足,使用0进行填充加载训练好的模型加载数据映射字典,获得结果
完整源码

【PyTorch深度学习项目实战100例】—— 基于pytorch使用LSTM实现新闻本文分类任务 | 第9例_咕 嘟的博客-CSDN博客_pytorch项目

本帖子中包含更多资源

您需要 登录 才可以下载或查看,没有账号?立即注册

×
懒得打字嘛,点击右侧快捷回复 【右侧内容,后台自定义】
您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

小黑屋|手机版|Unity开发者联盟 ( 粤ICP备20003399号 )

GMT+8, 2024-11-24 15:40 , Processed in 0.119314 second(s), 26 queries .

Powered by Discuz! X3.5 Licensed

© 2001-2024 Discuz! Team.

快速回复 返回顶部 返回列表