当前位置: 首页 > news >正文

利用(Transfer Learning)迁移学习在IMDB数据上训练一个文本分类模型

1. 背景

有些场景下,开始的时候数据量很小,如果我们用一个几千条数据训练一个全新的深度机器学习的文本分类模型,效果不会很好。这个时候你有两种选择,1.用传统的机器学习训练,2.利用迁移学习在一个预训练的模型上训练。本博客教你怎么用tensorflow Hub和keras 在少量的数据上训练一个文本分类模型。

2. 实践

2.1. 下载IMDB 数据集,参考下面博客。

Imdb影评的数据集介绍与下载_imdb影评数据集-CSDN博客

2.2.  预处理数据

替换掉imdb目录 (imdb_raw_data_dir). 创建dataset目录。

import numpy as np
import os as osimport re
from sklearn.model_selection import train_test_splitvocab_size = 30000
maxlen = 200
imdb_raw_data_dir = "/Users/harry/Documents/apps/ml/aclImdb"
save_dir = "dataset"def get_data(datapath =r'D:\train_data\aclImdb\aclImdb\train' ):pos_files = os.listdir(datapath + '/pos')neg_files = os.listdir(datapath + '/neg')print(len(pos_files))print(len(neg_files))pos_all = []neg_all = []for pf, nf in zip(pos_files, neg_files):with open(datapath + '/pos' + '/' + pf, encoding='utf-8') as f:s = f.read()s = process(s)pos_all.append(s)with open(datapath + '/neg' + '/' + nf, encoding='utf-8') as f:s = f.read()s = process(s)neg_all.append(s)print(len(pos_all))# print(pos_all[0])print(len(neg_all))X_orig= np.array(pos_all + neg_all)# print(X_orig)Y_orig = np.array([1 for _ in range(len(pos_all))] + [0 for _ in range(len(neg_all))])print("X_orig:", X_orig.shape)print("Y_orig:", Y_orig.shape)return X_orig, Y_origdef generate_dataset():X_orig, Y_orig = get_data(imdb_raw_data_dir + r'/train')X_orig_test, Y_orig_test = get_data(imdb_raw_data_dir + r'/test')X_orig = np.concatenate([X_orig, X_orig_test])Y_orig = np.concatenate([Y_orig, Y_orig_test])X = X_origY = Y_orignp.random.seed = 1random_indexs = np.random.permutation(len(X))X = X[random_indexs]Y = Y[random_indexs]X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size=0.3)print("X_train:", X_train.shape)print("y_train:", y_train.shape)print("X_test:", X_test.shape)print("y_test:", y_test.shape)np.savez(save_dir + '/train_test', X_train=X_train, y_train=y_train, X_test= X_test, y_test=y_test )def rm_tags(text):re_tag = re.compile(r'<[^>]+>')return re_tag.sub(' ', text)def clean_str(string):string = re.sub(r"[^A-Za-z0-9(),!?\'\`]", " ", string)string = re.sub(r"\'s", " \'s", string)  # it's -> it 'sstring = re.sub(r"\'ve", " \'ve", string) # I've -> I 'vestring = re.sub(r"n\'t", " n\'t", string) # doesn't -> does n'tstring = re.sub(r"\'re", " \'re", string) # you're -> you arestring = re.sub(r"\'d", " \'d", string)  # you'd -> you 'dstring = re.sub(r"\'ll", " \'ll", string) # you'll -> you 'llstring = re.sub(r"\'m", " \'m", string) # I'm -> I 'mstring = re.sub(r",", " , ", string)string = re.sub(r"!", " ! ", string)string = re.sub(r"\(", " \( ", string)string = re.sub(r"\)", " \) ", string)string = re.sub(r"\?", " \? ", string)string = re.sub(r"\s{2,}", " ", string)return string.strip().lower()def process(text):text = clean_str(text)text = rm_tags(text)#text = text.lower()return  textif __name__ == '__main__':generate_dataset()

执行完后,产生train_test.npz 文件

2.3.  训练模型

1. 取数据集

def get_dataset_to_train():train_test = np.load('dataset/train_test.npz', allow_pickle=True)x_train =  train_test['X_train']y_train = train_test['y_train']x_test =  train_test['X_test']y_test = train_test['y_test']return x_train, y_train, x_test, y_test

2. 创建模型

基于nnlm-en-dim50/2 预训练的文本嵌入向量,在模型外面加了两层全连接。

def get_model():hub_layer = hub.KerasLayer(embedding_url, input_shape=[], dtype=tf.string, trainable=True)# Build the modelmodel = Sequential([hub_layer,Dense(16, activation='relu'),Dropout(0.5),Dense(2, activation='softmax')])print(model.summary())model.compile(optimizer=keras.optimizers.Adam(),loss=keras.losses.SparseCategoricalCrossentropy(),metrics=[keras.metrics.SparseCategoricalAccuracy()])return model

还可以使用来自 TFHub 的许多其他预训练文本嵌入向量:

  • google/nnlm-en-dim128/2 - 基于与 google/nnlm-en-dim50/2 相同的数据并使用相同的 NNLM 架构进行训练,但具有更大的嵌入向量维度。更大维度的嵌入向量可以改进您的任务,但可能需要更长的时间来训练您的模型。
  • google/nnlm-en-dim128-with-normalization/2 - 与 google/nnlm-en-dim128/2 相同,但具有额外的文本归一化,例如移除标点符号。如果您的任务中的文本包含附加字符或标点符号,这会有所帮助。
  • google/universal-sentence-encoder/4 - 一个可产生 512 维嵌入向量的更大模型,使用深度平均网络 (DAN) 编码器训练。

还有很多!在 TFHub 上查找更多文本嵌入向量模型。

3. 评估你的模型

def evaluate_model(test_data, test_labels):model = load_trained_model()# Evaluate the modelresults = model.evaluate(test_data, test_labels, verbose=2)print("Test accuracy:", results[1])def load_trained_model():# model = get_model()# model.load_weights('./models/model_new1.h5')model = tf.keras.models.load_model('models_pb')return model

4. 测试几个例子

def predict(real_data):model  = load_trained_model()probabilities = model.predict([real_data]);print("probabilities :",probabilities)result =  get_label(probabilities)return resultdef get_label(probabilities):index = np.argmax(probabilities[0])print("index :" + str(index))result_str =  index_dic.get(str(index))# result_str = list(index_dic.keys())[list(index_dic.values()).index(index)]return result_strdef predict_my_module():# review = "I don't like it"# review = "this is bad movie "# review = "This is good movie"review = " this is terrible movie"# review = "This isn‘t great movie"# review = "i think this is bad movie"# review = "I'm not very disappoint for this movie"# review = "I'm not very disappoint for this movie"# review = "I am very happy for this movie"#neg:0 postive:1s = predict(review)print(s)if __name__ == '__main__':x_train, y_train, x_test, y_test = get_dataset_to_train()model = get_model()model = train(model, x_train, y_train, x_test, y_test)evaluate_model(x_test, y_test)predict_my_module()

完整代码

import numpy as np
import tensorflow as tf
from keras.models import Sequential
from keras.layers import Dense, Dropout
import keras as keras
from keras.callbacks import EarlyStopping, ModelCheckpoint
import tensorflow_hub as hubembedding_url = "https://tfhub.dev/google/nnlm-en-dim50/2"index_dic = {"0":"negative", "1": "positive"}def get_dataset_to_train():train_test = np.load('dataset/train_test.npz', allow_pickle=True)x_train =  train_test['X_train']y_train = train_test['y_train']x_test =  train_test['X_test']y_test = train_test['y_test']return x_train, y_train, x_test, y_testdef get_model():hub_layer = hub.KerasLayer(embedding_url, input_shape=[], dtype=tf.string, trainable=True)# Build the modelmodel = Sequential([hub_layer,Dense(16, activation='relu'),Dropout(0.5),Dense(2, activation='softmax')])print(model.summary())model.compile(optimizer=keras.optimizers.Adam(),loss=keras.losses.SparseCategoricalCrossentropy(),metrics=[keras.metrics.SparseCategoricalAccuracy()])return modeldef train(model , train_data, train_labels, test_data, test_labels):# train_data, train_labels, test_data, test_labels = get_dataset_to_train()train_data = [tf.compat.as_str(tf.compat.as_bytes(str(x))) for x in train_data]test_data = [tf.compat.as_str(tf.compat.as_bytes(str(x))) for x in test_data]train_data = np.asarray(train_data)  # Convert to numpy arraytest_data = np.asarray(test_data)  # Convert to numpy arrayprint(train_data.shape, test_data.shape)early_stop = EarlyStopping(monitor='val_sparse_categorical_accuracy', patience=4, mode='max', verbose=1)# 定义ModelCheckpoint回调函数# checkpoint = ModelCheckpoint( './models/model_new1.h5', monitor='val_sparse_categorical_accuracy', save_best_only=True,#                              mode='max', verbose=1)checkpoint_pb = ModelCheckpoint(filepath="./models_pb/",  monitor='val_sparse_categorical_accuracy', save_weights_only=False, save_best_only=True)history = model.fit(train_data[:2000], train_labels[:2000], epochs=45, batch_size=45, validation_data=(test_data, test_labels), shuffle=True,verbose=1, callbacks=[early_stop, checkpoint_pb])print("history", history)return modeldef evaluate_model(test_data, test_labels):model = load_trained_model()# Evaluate the modelresults = model.evaluate(test_data, test_labels, verbose=2)print("Test accuracy:", results[1])def predict(real_data):model  = load_trained_model()probabilities = model.predict([real_data]);print("probabilities :",probabilities)result =  get_label(probabilities)return resultdef get_label(probabilities):index = np.argmax(probabilities[0])print("index :" + str(index))result_str =  index_dic.get(str(index))# result_str = list(index_dic.keys())[list(index_dic.values()).index(index)]return result_strdef load_trained_model():# model = get_model()# model.load_weights('./models/model_new1.h5')model = tf.keras.models.load_model('models_pb')return modeldef predict_my_module():# review = "I don't like it"# review = "this is bad movie "# review = "This is good movie"review = " this is terrible movie"# review = "This isn‘t great movie"# review = "i think this is bad movie"# review = "I'm not very disappoint for this movie"# review = "I'm not very disappoint for this movie"# review = "I am very happy for this movie"#neg:0 postive:1s = predict(review)print(s)if __name__ == '__main__':x_train, y_train, x_test, y_test = get_dataset_to_train()model = get_model()model = train(model, x_train, y_train, x_test, y_test)evaluate_model(x_test, y_test)predict_my_module()

相关文章:

利用(Transfer Learning)迁移学习在IMDB数据上训练一个文本分类模型

1. 背景 有些场景下&#xff0c;开始的时候数据量很小&#xff0c;如果我们用一个几千条数据训练一个全新的深度机器学习的文本分类模型&#xff0c;效果不会很好。这个时候你有两种选择&#xff0c;1.用传统的机器学习训练&#xff0c;2.利用迁移学习在一个预训练的模型上训练…...

pom.xml格式化快捷键

在软件开发和编程领域&#xff0c;"格式化"通常指的是将代码按照一定的规范和风格进行排列&#xff0c;以提高代码的可读性和维护性。格式化代码有助于使代码结构清晰、统一&#xff0c;并符合特定的编码规范。 格式化可以包括以下方面&#xff1a; 缩进&#xff1a…...

【短文】【踩坑】可以在Qt Designer给QTableWidge添加右键菜单吗?

2023年11月18日&#xff0c;周六上午 今天早上在网上找了好久都没找到教怎么在Qt Designer给QTableWidge添加右键菜单的文章 答案是&#xff1a;不可以 在Qt Designer中无法直接为QTableWidget添加右键菜单。 Qt Designer主要用于创建界面布局和设计&#xff0c;无法直接添加…...

Git常用配置

git log 美化输出 全局配置参数 git config --global alias.lm "log --no-merges --color --dateformat:%Y-%m-%d %H:%M:%S --authorghost --prettyformat:%Cred%h%Creset - %Cgreen(%cd)%C(yellow)%d%Cblue %s %C(bold blue)<%an>%Creset --abbrev-commit"…...

力扣每日一题-数位和相等数对的最大和-2023.11.18

力扣每日一题&#xff1a;数位和相等数对的最大和 开篇 这道每日一题还是挺需要思考的&#xff0c;我绕晕了好久&#xff0c;根据题解的提示才写出来。 题目链接:2342.数位和相等数对的最大和 题目描述 代码思路 1.创建一个数组存储每个数位的数的最大值&#xff0c;创建一…...

【win32_001】win32命名规、缩写、窗口

整数类型 bool类型 使用注意&#xff1a; 一般bool 的false0&#xff1b;true1 | 2 | …|n false是为0&#xff0c;true是非零 不建议这样用&#xff1a; if (result TRUE) // Wrong! 因为result不一定只返回1&#xff08;true&#xff09;&#xff0c;当返回2时&#xff0c…...

机器学习第8天:SVM分类

文章目录 机器学习专栏 介绍 特征缩放 示例代码 硬间隔与软间隔分类 主要代码 代码解释 非线性SVM分类 结语 机器学习专栏 机器学习_Nowl的博客-CSDN博客 介绍 作用&#xff1a;判别种类 原理&#xff1a;找出一个决策边界&#xff0c;判断数据所处区域来识别种类 简单…...

AI工具合集

网站&#xff1a;未来百科 | 为发现全球优质AI工具产品而生 (6aiq.com) 如今&#xff0c;AI技术涉及到了很多领域&#xff0c;比如去水印、一键抠图、图像处理、AI图像生成等等。站长之家之前也分享过一些&#xff0c;但是在网上要搜索找到它们还是费一些功夫。 今天发现了一…...

代码随想录算法训练营Day 54 || 392.判断子序列、115.不同的子序列

392.判断子序列 力扣题目链接(opens new window) 给定字符串 s 和 t &#xff0c;判断 s 是否为 t 的子序列。 字符串的一个子序列是原始字符串删除一些&#xff08;也可以不删除&#xff09;字符而不改变剩余字符相对位置形成的新字符串。&#xff08;例如&#xff0c;&quo…...

C 语言 gets()和puts()

C 语言 gets()和puts() gets()和puts()在头文件stdio.h中声明。这两个函数用于字符串的输入/输出操作。 C gets()函数 gets()函数使用户可以输入一些字符&#xff0c;然后按Enter键。 用户输入的所有字符都存储在字符数组中。 空字符将添加到数组以使其成为字符串。 gets()允…...

核—幂零分解

若向量空间 V \mathcal V V存在子空间 X \mathcal X X与 Y \mathcal Y Y&#xff0c;当 X Y V X ∩ Y 0 \mathcal {X\text{}Y\text{}V}\\ \mathcal {X}\cap \mathcal {Y}0 XYVX∩Y0 时称子空间 X \mathcal X X与 Y \mathcal Y Y是完备的&#xff0c;其中记为 X ⊕ Y V \ma…...

轻松掌控财务,分析账户花销,明细记录支出情况

随着科技的发展&#xff0c;我们的生活变得越来越智能化。然而&#xff0c;对于许多忙碌的现代人来说&#xff0c;管理财务可能是一件令人头疼的事情。复杂的账单、花销、收入&#xff0c;这些可能会让你感到无从下手。但现在&#xff0c;我们有一个全新的解决方案——一款全新…...

竞赛 题目:基于机器视觉opencv的手势检测 手势识别 算法 - 深度学习 卷积神经网络 opencv python

文章目录 1 简介2 传统机器视觉的手势检测2.1 轮廓检测法2.2 算法结果2.3 整体代码实现2.3.1 算法流程 3 深度学习方法做手势识别3.1 经典的卷积神经网络3.2 YOLO系列3.3 SSD3.4 实现步骤3.4.1 数据集3.4.2 图像预处理3.4.3 构建卷积神经网络结构3.4.4 实验训练过程及结果 3.5 …...

11. Spring源码篇之实例化前的后置处理器

简介 spring在创建Bean的过程中&#xff0c;提供了很多个生命周期&#xff0c;实例化前就是比较早的一个生命周期&#xff0c;顾名思义就是在Bean被实例化之前的处理&#xff0c;这个时候还没实例化&#xff0c;只能拿到该Bean的Class对象&#xff0c;如果在这个时候直接返回一…...

Python-Python高阶技巧:HTTP协议、静态Web服务器程序开发、循环接收客户端的连接请求

版本说明 当前版本号[20231114]。 版本修改说明20231114初版 目录 文章目录 版本说明目录HTTP协议1、网址1.1 网址的概念1.2 URL的组成1.3 知识要点 2、HTTP协议的介绍2.1 HTTP协议的概念及作用2.2 HTTP协议的概念及作用2.3 浏览器访问Web服务器的过程 3、HTTP请求报文3.1 H…...

P1304 哥德巴赫猜想

题目描述 输入一个偶数 N,验证 4∼N 所有偶数是否符合哥德巴赫猜想:任一大于 22 的偶数都可写成两个质数之和。如果一个数不止一种分法,则输出第一个加数相比其他分法最小的方案。例如 1010,10=3+7=5+510=3+7=5+5,则 10=5+510=5+5 是错误答案。 输入格式 第一行输入一个…...

CSDN每日一题学习训练——Python版(搜索插入位置、最大子序和)

版本说明 当前版本号[20231118]。 版本修改说明20231118初版 目录 文章目录 版本说明目录搜索插入位置题目解题思路代码思路参考代码 最大子序和题目解题思路代码思路参考代码 搜索插入位置 题目 给定一个排序数组和一个目标值&#xff0c;在数组中找到目标值&#xff0c;…...

Java在物联网中的重要性

【点我-这里送书】 本人详解 作者:王文峰,参加过 CSDN 2020年度博客之星,《Java王大师王天师》 公众号:JAVA开发王大师,专注于天道酬勤的 Java 开发问题中国国学、传统文化和代码爱好者的程序人生,期待你的关注和支持!本人外号:神秘小峯 山峯 转载说明:务必注明来源(…...

动态规划解背包问题

题目 题解 def knapsac(W: int, N: int, wt: List[int], val: List[int]) -> int:# 定义状态动作价值函数: dp[i][j]&#xff0c;对于前i个物品&#xff0c;当前背包容量为j&#xff0c;最大的可装载价值dp [[0 for j in range(W1)] for i in range(N1)]# 状态动作转移for…...

PCL内置点云类型

PCL内置了许多点云类型供我们使用&#xff0c;下面先介绍PLC内置的点云数据类型 PCL中的点云类型为PointT&#xff1b;至于为什么是PointT类型需要追随到原来的ros开发中去&#xff0c;因为PCL库也是从原来的ROS中剥离出来的&#xff1b;大家都一致的认为点云结构是离散的N维信…...

告别GPS模块!用IRIG-B码为你的工业设备打造超高性价比的10ns同步时钟源

工业级10ns同步时钟方案&#xff1a;IRIG-B解码模块的实战应用指南 在工业自动化、电力系统和精密测试测量领域&#xff0c;时间同步精度往往直接关系到系统运行的可靠性与数据采集的准确性。传统GPS/北斗模块虽然普及&#xff0c;却面临着信号覆盖受限、设备成本高昂以及潜在安…...

Spring Boot项目必备:用Arthas实现MyBatis Mapper热加载的完整配置流程

Spring Boot项目必备&#xff1a;用Arthas实现MyBatis Mapper热加载的完整配置流程 在持续交付的微服务架构中&#xff0c;开发团队经常面临一个共同挑战&#xff1a;每次修改MyBatis的Mapper XML文件后&#xff0c;都需要重启服务才能验证变更效果。这种低效的反馈循环严重拖慢…...

微信支付ApiV3回调实战:Java版签名校验与参数解密全流程解析

1. 微信支付ApiV3回调的核心流程 微信支付ApiV3的回调机制是整个支付流程中非常关键的一环。当用户完成支付后&#xff0c;微信服务器会主动向商户服务器发送支付结果通知。这个通知包含了支付状态、金额等重要信息&#xff0c;但为了确保数据安全&#xff0c;微信会对这些信息…...

Windows10下PaddleOCR与Python3.8.5的完美搭配:从安装到实战OCR识别

Windows10下PaddleOCR与Python3.8.5的深度实践指南 在数字化办公和自动化流程日益普及的今天&#xff0c;光学字符识别&#xff08;OCR&#xff09;技术已经成为从图像中提取文本信息的重要工具。PaddleOCR作为百度开源的OCR工具库&#xff0c;凭借其出色的识别准确率和易用性…...

告别网络调试焦虑:用STM32CubeMX+FreeRTOS,给LAN8720A和LWIP做个“健康检查”与性能小优化

STM32网络子系统深度优化&#xff1a;从连通性测试到工业级稳定性实战 当你熬夜调试的嵌入式设备终于能Ping通时&#xff0c;那种喜悦感堪比程序员第一次写出"Hello World"。但很快你会发现&#xff0c;真正的挑战才刚刚开始——那些在演示视频里永远不会出现的诡异断…...

IP离线库每周更新一次够用吗?企业风控建议多久更新?

在风控体系中&#xff0c;IP数据的时效性直接决定了拦截效果。当攻击者使用秒拨IP或住宅代理发起攻击时&#xff0c;IP地址的轮换速度可以达到分钟级。如果依赖的IP库更新周期过长&#xff0c;就等于在防御上留下了数天的空窗期。 周更不够用。秒拨IP平均存活3-5分钟&#xff…...

LeetCode138. 随机链表的复制(2024秋季每日一题 34)

给你一个长度为 n 的链表&#xff0c;每个节点包含一个额外增加的随机指针 random &#xff0c;该指针可以指向链表中的任何节点或空节点。 构造这个链表的 深拷贝。 深拷贝应该正好由 n 个 全新 节点组成&#xff0c;其中每个新节点的值都设为其对应的原节点的值。新节点的 ne…...

红蓝对抗深度解析:从技术体系到落地实践,企业安全真正的实战课

红蓝对抗深度解析&#xff1a;从技术体系到落地实践&#xff0c;企业安全真正的实战课 在数字化攻防进入 “实战对抗” 时代的今天&#xff0c;红蓝对抗已成为企业检验安全防御体系、提升应急响应能力的核心手段。不同于传统的漏洞扫描和合规检查&#xff0c;红蓝对抗以 “高仿…...

ROS2(2)配置:从WSL网络到Docker容器GUI显示的完整链路

1. WSL2网络架构解析与ROS2容器网络配置 在WSL2Docker环境中运行ROS2时&#xff0c;网络问题是最常见的拦路虎。我刚开始用这个组合时&#xff0c;经常遇到镜像拉取超时、容器内无法访问外网的情况&#xff0c;后来才发现问题出在对WSL2网络机制的理解不足上。 WSL2采用虚拟化技…...

Hermes邮件生成器详解:如何配置产品信息和自定义主题

Hermes邮件生成器详解&#xff1a;如何配置产品信息和自定义主题 【免费下载链接】hermes Golang package that generates clean, responsive HTML e-mails for sending transactional mail 项目地址: https://gitcode.com/gh_mirrors/he/hermes Hermes是一款强大的Go语…...