当前位置: 首页 > news >正文

(Transfer Learning)迁移学习在IMDB上训练情感分析模型

1. 背景

有些场景下,开始的时候数据量很小,如果我们用一个几千条数据训练一个全新的深度机器学习的文本分类模型,效果不会很好。这个时候你有两种选择,1.用传统的机器学习训练,2.利用迁移学习在一个预训练的模型上训练。本博客教你怎么用tensorflow Hub和keras 在少量的数据上训练一个文本分类模型。

2. 实践

2.1. 下载IMDB 数据集,参考下面博客。

Imdb影评的数据集介绍与下载_imdb影评数据集-CSDN博客

2.2.  预处理数据

替换掉imdb目录 (imdb_raw_data_dir). 创建dataset目录。

import numpy as np
import os as osimport re
from sklearn.model_selection import train_test_splitvocab_size = 30000
maxlen = 200
imdb_raw_data_dir = "/Users/harry/Documents/apps/ml/aclImdb"
save_dir = "dataset"def get_data(datapath =r'D:\train_data\aclImdb\aclImdb\train' ):pos_files = os.listdir(datapath + '/pos')neg_files = os.listdir(datapath + '/neg')print(len(pos_files))print(len(neg_files))pos_all = []neg_all = []for pf, nf in zip(pos_files, neg_files):with open(datapath + '/pos' + '/' + pf, encoding='utf-8') as f:s = f.read()s = process(s)pos_all.append(s)with open(datapath + '/neg' + '/' + nf, encoding='utf-8') as f:s = f.read()s = process(s)neg_all.append(s)print(len(pos_all))# print(pos_all[0])print(len(neg_all))X_orig= np.array(pos_all + neg_all)# print(X_orig)Y_orig = np.array([1 for _ in range(len(pos_all))] + [0 for _ in range(len(neg_all))])print("X_orig:", X_orig.shape)print("Y_orig:", Y_orig.shape)return X_orig, Y_origdef generate_dataset():X_orig, Y_orig = get_data(imdb_raw_data_dir + r'/train')X_orig_test, Y_orig_test = get_data(imdb_raw_data_dir + r'/test')X_orig = np.concatenate([X_orig, X_orig_test])Y_orig = np.concatenate([Y_orig, Y_orig_test])X = X_origY = Y_orignp.random.seed = 1random_indexs = np.random.permutation(len(X))X = X[random_indexs]Y = Y[random_indexs]X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size=0.3)print("X_train:", X_train.shape)print("y_train:", y_train.shape)print("X_test:", X_test.shape)print("y_test:", y_test.shape)np.savez(save_dir + '/train_test', X_train=X_train, y_train=y_train, X_test= X_test, y_test=y_test )def rm_tags(text):re_tag = re.compile(r'<[^>]+>')return re_tag.sub(' ', text)def clean_str(string):string = re.sub(r"[^A-Za-z0-9(),!?\'\`]", " ", string)string = re.sub(r"\'s", " \'s", string)  # it's -> it 'sstring = re.sub(r"\'ve", " \'ve", string) # I've -> I 'vestring = re.sub(r"n\'t", " n\'t", string) # doesn't -> does n'tstring = re.sub(r"\'re", " \'re", string) # you're -> you arestring = re.sub(r"\'d", " \'d", string)  # you'd -> you 'dstring = re.sub(r"\'ll", " \'ll", string) # you'll -> you 'llstring = re.sub(r"\'m", " \'m", string) # I'm -> I 'mstring = re.sub(r",", " , ", string)string = re.sub(r"!", " ! ", string)string = re.sub(r"\(", " \( ", string)string = re.sub(r"\)", " \) ", string)string = re.sub(r"\?", " \? ", string)string = re.sub(r"\s{2,}", " ", string)return string.strip().lower()def process(text):text = clean_str(text)text = rm_tags(text)#text = text.lower()return  textif __name__ == '__main__':generate_dataset()

执行完后,产生train_test.npz 文件

2.3.  训练模型

1. 取数据集

def get_dataset_to_train():train_test = np.load('dataset/train_test.npz', allow_pickle=True)x_train =  train_test['X_train']y_train = train_test['y_train']x_test =  train_test['X_test']y_test = train_test['y_test']return x_train, y_train, x_test, y_test

2. 创建模型

基于nnlm-en-dim50/2 预训练的文本嵌入向量,在模型外面加了两层全连接。

def get_model():hub_layer = hub.KerasLayer(embedding_url, input_shape=[], dtype=tf.string, trainable=True)# Build the modelmodel = Sequential([hub_layer,Dense(16, activation='relu'),Dropout(0.5),Dense(2, activation='softmax')])print(model.summary())model.compile(optimizer=keras.optimizers.Adam(),loss=keras.losses.SparseCategoricalCrossentropy(),metrics=[keras.metrics.SparseCategoricalAccuracy()])return model

还可以使用来自 TFHub 的许多其他预训练文本嵌入向量:

  • google/nnlm-en-dim128/2 - 基于与 google/nnlm-en-dim50/2 相同的数据并使用相同的 NNLM 架构进行训练,但具有更大的嵌入向量维度。更大维度的嵌入向量可以改进您的任务,但可能需要更长的时间来训练您的模型。
  • google/nnlm-en-dim128-with-normalization/2 - 与 google/nnlm-en-dim128/2 相同,但具有额外的文本归一化,例如移除标点符号。如果您的任务中的文本包含附加字符或标点符号,这会有所帮助。
  • google/universal-sentence-encoder/4 - 一个可产生 512 维嵌入向量的更大模型,使用深度平均网络 (DAN) 编码器训练。

还有很多!在 TFHub 上查找更多文本嵌入向量模型。

3. 评估你的模型

def evaluate_model(test_data, test_labels):model = load_trained_model()# Evaluate the modelresults = model.evaluate(test_data, test_labels, verbose=2)print("Test accuracy:", results[1])def load_trained_model():# model = get_model()# model.load_weights('./models/model_new1.h5')model = tf.keras.models.load_model('models_pb')return model

4. 测试几个例子

def predict(real_data):model  = load_trained_model()probabilities = model.predict([real_data]);print("probabilities :",probabilities)result =  get_label(probabilities)return resultdef get_label(probabilities):index = np.argmax(probabilities[0])print("index :" + str(index))result_str =  index_dic.get(str(index))# result_str = list(index_dic.keys())[list(index_dic.values()).index(index)]return result_strdef predict_my_module():# review = "I don't like it"# review = "this is bad movie "# review = "This is good movie"review = " this is terrible movie"# review = "This isn‘t great movie"# review = "i think this is bad movie"# review = "I'm not very disappoint for this movie"# review = "I'm not very disappoint for this movie"# review = "I am very happy for this movie"#neg:0 postive:1s = predict(review)print(s)if __name__ == '__main__':x_train, y_train, x_test, y_test = get_dataset_to_train()model = get_model()model = train(model, x_train, y_train, x_test, y_test)evaluate_model(x_test, y_test)predict_my_module()

完整代码

import numpy as np
import tensorflow as tf
from keras.models import Sequential
from keras.layers import Dense, Dropout
import keras as keras
from keras.callbacks import EarlyStopping, ModelCheckpoint
import tensorflow_hub as hubembedding_url = "https://tfhub.dev/google/nnlm-en-dim50/2"index_dic = {"0":"negative", "1": "positive"}def get_dataset_to_train():train_test = np.load('dataset/train_test.npz', allow_pickle=True)x_train =  train_test['X_train']y_train = train_test['y_train']x_test =  train_test['X_test']y_test = train_test['y_test']return x_train, y_train, x_test, y_testdef get_model():hub_layer = hub.KerasLayer(embedding_url, input_shape=[], dtype=tf.string, trainable=True)# Build the modelmodel = Sequential([hub_layer,Dense(16, activation='relu'),Dropout(0.5),Dense(2, activation='softmax')])print(model.summary())model.compile(optimizer=keras.optimizers.Adam(),loss=keras.losses.SparseCategoricalCrossentropy(),metrics=[keras.metrics.SparseCategoricalAccuracy()])return modeldef train(model , train_data, train_labels, test_data, test_labels):# train_data, train_labels, test_data, test_labels = get_dataset_to_train()train_data = [tf.compat.as_str(tf.compat.as_bytes(str(x))) for x in train_data]test_data = [tf.compat.as_str(tf.compat.as_bytes(str(x))) for x in test_data]train_data = np.asarray(train_data)  # Convert to numpy arraytest_data = np.asarray(test_data)  # Convert to numpy arrayprint(train_data.shape, test_data.shape)early_stop = EarlyStopping(monitor='val_sparse_categorical_accuracy', patience=4, mode='max', verbose=1)# 定义ModelCheckpoint回调函数# checkpoint = ModelCheckpoint( './models/model_new1.h5', monitor='val_sparse_categorical_accuracy', save_best_only=True,#                              mode='max', verbose=1)checkpoint_pb = ModelCheckpoint(filepath="./models_pb/",  monitor='val_sparse_categorical_accuracy', save_weights_only=False, save_best_only=True)history = model.fit(train_data[:2000], train_labels[:2000], epochs=45, batch_size=45, validation_data=(test_data, test_labels), shuffle=True,verbose=1, callbacks=[early_stop, checkpoint_pb])print("history", history)return modeldef evaluate_model(test_data, test_labels):model = load_trained_model()# Evaluate the modelresults = model.evaluate(test_data, test_labels, verbose=2)print("Test accuracy:", results[1])def predict(real_data):model  = load_trained_model()probabilities = model.predict([real_data]);print("probabilities :",probabilities)result =  get_label(probabilities)return resultdef get_label(probabilities):index = np.argmax(probabilities[0])print("index :" + str(index))result_str =  index_dic.get(str(index))# result_str = list(index_dic.keys())[list(index_dic.values()).index(index)]return result_strdef load_trained_model():# model = get_model()# model.load_weights('./models/model_new1.h5')model = tf.keras.models.load_model('models_pb')return modeldef predict_my_module():# review = "I don't like it"# review = "this is bad movie "# review = "This is good movie"review = " this is terrible movie"# review = "This isn‘t great movie"# review = "i think this is bad movie"# review = "I'm not very disappoint for this movie"# review = "I'm not very disappoint for this movie"# review = "I am very happy for this movie"#neg:0 postive:1s = predict(review)print(s)if __name__ == '__main__':x_train, y_train, x_test, y_test = get_dataset_to_train()model = get_model()model = train(model, x_train, y_train, x_test, y_test)evaluate_model(x_test, y_test)predict_my_module()

相关文章:

(Transfer Learning)迁移学习在IMDB上训练情感分析模型

1. 背景 有些场景下&#xff0c;开始的时候数据量很小&#xff0c;如果我们用一个几千条数据训练一个全新的深度机器学习的文本分类模型&#xff0c;效果不会很好。这个时候你有两种选择&#xff0c;1.用传统的机器学习训练&#xff0c;2.利用迁移学习在一个预训练的模型上训练…...

蓝桥杯每日一题2023.11.20

题目描述 “蓝桥杯”练习系统 (lanqiao.cn) 题目分析 方法一&#xff1a;暴力枚举&#xff0c;如果说数字不在正确的位置上也就意味着这个数必须要改变&#xff0c;进行改变记录即可 #include<bits/stdc.h> using namespace std; const int N 2e5 10; int n, a[N], …...

【迅搜02】究竟什么是搜索引擎?正式介绍XunSearch

究竟什么是搜索引擎&#xff1f;正式介绍XunSearch 啥&#xff1f;还要单独讲一下啥是搜索引擎&#xff1f;不就是百度、Google嘛&#xff0c;这玩意天天用&#xff0c;还轮的到你来说&#xff1f; 额&#xff0c;好吧&#xff0c;虽然大家天天都在用&#xff0c;但是我发现&am…...

【Sql】sql server还原数据库的时候,提示:因为数据库正在使用,所以无法获得对数据库的独占访问权。

【问题描述】 sql server 还数据库的时候&#xff0c;提示失败。 点击左下角进度位置&#xff0c;可以得到详细信息&#xff1a; 因为数据库正在使用&#xff0c;所以无法获得对数据库的独占访问权。 【解决方法】 针对数据库先后执行下述语句&#xff0c;获得独占访问权后&a…...

【Go语言实战】(26) 分布式搜索引擎

Tangseng 基于Go语言的搜索引擎 github地址&#xff1a;https://github.com/CocaineCong/tangseng 详细介绍地址&#xff1a;https://cocainecong.github.io/tangseng 这两周我也抽空录成视频发到B站的&#xff5e; 本来应该10月份就要发了&#xff0c;结果一鸽就鸽到现在hh…...

【理解ARM架构】不同方式点灯 | ARM架构简介 | 常见汇编指令 | C与汇编

&#x1f431;作者&#xff1a;一只大喵咪1201 &#x1f431;专栏&#xff1a;《理解ARM架构》 &#x1f525;格言&#xff1a;你只管努力&#xff0c;剩下的交给时间&#xff01; 目录 &#x1f3c0;直接操作寄存器点亮LED灯&#x1f3c0;地址空间&#x1f3c0;ARM内部的寄存…...

JS服务端技术—Node.js知识点锦集

【版权声明】未经博主同意&#xff0c;谢绝转载&#xff01;&#xff08;请尊重原创&#xff0c;博主保留追究权&#xff09; https://blog.csdn.net/m0_69908381/article/details/134544523 出自【进步*于辰的博客】 接触Node.js挺长时间了&#xff0c;工作也经常使用&#xf…...

界面控件DevExpress WPF流程图组件,完美复制Visio UI!(一)

DevExpress WPF Diagram&#xff08;流程图&#xff09;控件帮助用户完美复制Microsoft Visio UI&#xff0c;并将信息丰富且组织良好的图表、流程图和组织图轻松合并到您的下一个WPF项目中。 P.S&#xff1a;DevExpress WPF拥有120个控件和库&#xff0c;将帮助您交付满足甚至…...

为什么选择B+树作为数据库索引结构?

背景 首先&#xff0c;来谈谈B树。为什么要使用B树&#xff1f;我们需要明白以下两个事实&#xff1a; 【事实1】 不同容量的存储器&#xff0c;访问速度差异悬殊。以磁盘和内存为例&#xff0c;访问磁盘的时间大概是ms级的&#xff0c;访问内存的时间大概是ns级的。有个形象…...

什么是神经网络(Neural Network,NN)

1 定义 神经网络是一种模拟人类大脑工作方式的计算模型&#xff0c;它是深度学习和机器学习领域的基础。神经网络由大量的节点&#xff08;或称为“神经元”&#xff09;组成&#xff0c;这些节点在网络中相互连接&#xff0c;可以处理复杂的数据输入&#xff0c;执行各种任务…...

15 Go的并发

概述 在上一节的内容中&#xff0c;我们介绍了Go的类型转换&#xff0c;包括&#xff1a;断言类型转换、显式类型转换、隐式类型转换、strconv包等。在本节中&#xff0c;我们将介绍Go的并发。Go语言以其强大的并发模型而闻名&#xff0c;其并发特性主要通过以下几个元素来实现…...

管理体系标准

管理体系标准 什么是管理体系&#xff1f; 管理体系是组织管理其业务的相互关联部分以实现其目标的方式。这些目标可能涉及许多不同的主题&#xff0c;包括产品或服务质量、运营效率、环境绩效、工作场所的健康和安全等等。 系统的复杂程度取决于每个组织的具体情况。对于某…...

【Java 进阶篇】揭秘 Jackson:Java 对象转 JSON 注解的魔法

嗨&#xff0c;亲爱的同学们&#xff01;欢迎来到这篇关于 Jackson JSON 解析器中 Java 对象转 JSON 注解的详细解析指南。JSON&#xff08;JavaScript Object Notation&#xff09;是一种常用于数据交换的轻量级数据格式&#xff0c;而 Jackson 作为一款优秀的 JSON 解析库&am…...

②【Hash】Redis常用数据类型:Hash [使用手册]

个人简介&#xff1a;Java领域新星创作者&#xff1b;阿里云技术博主、星级博主、专家博主&#xff1b;正在Java学习的路上摸爬滚打&#xff0c;记录学习的过程~ 个人主页&#xff1a;.29.的博客 学习社区&#xff1a;进去逛一逛~ Redis Hash ②Redis Hash 操作命令汇总1. hset…...

十七、SpringAMQP

目录 一、SpringAMQP的介绍&#xff1a; 二、利用SpringAMQP实现HelloWorld中的基础消息队列功能 1、因为publisher和consumer服务都需要amqp依赖&#xff0c;因此这里把依赖直接放到父工程mq-demo中 2、编写yml文件 3、编写测试类&#xff0c;并进行测试 三、在consumer…...

Java虚拟机(JVM)的调优技巧和实战

JVM是Java应用程序的运行环境&#xff0c;它负责管理Java应用程序的内存分配、垃圾收集等重要任务。然而&#xff0c;JVM的默认设置并不总是适合所有应用程序&#xff0c;因此需要根据应用程序的需求进行调优。通过对JVM进行调优&#xff0c;可以大大提高Java应用程序的性能和可…...

idea中的sout、psvm快捷键输入,不要太好用了

目录 一、操作环境 二、psvm、sout 操作介绍 2.1 psvm&#xff0c;快捷生成main方法 2.2 sout&#xff0c;快捷生成打印方法 三、探索 psvm、sout 底层逻辑 一、操作环境 语言&#xff1a;Java 工具&#xff1a; 二、psvm、sout 操作介绍 2.1 psvm&#xff0c;快捷生成m…...

shell脚本字典创建遍历打印

解释&#xff1a; 代码块中包含了每个用法的详细解释 #!/bin/bash# 接收用户输入的两个数 echo "请输入第一个数&#xff1a;" read num1 echo "请输入第二个数&#xff1a;" read num2# 创建一个关联数组 declare -A dict1 declare -A dict2# 定义键和值…...

【设计模式】聊聊职责链模式

原理和实现 模板模式变化的是其中一个步骤&#xff0c;而责任链模式变化的是整个流程。 将请求的发送和接收解耦合&#xff0c;让多个接收对象有机会可以处理这个请求&#xff0c;形成一个链条。不同的处理器负责自己不同的职责。 定义接口 public interface Filter {/*** …...

【C++进阶之路】第五篇:哈希

文章目录 一、unordered系列关联式容器1.unordered_map&#xff08;1&#xff09;unordered_map的介绍&#xff08;2&#xff09;unordered_map的接口说明 2. unordered_set3.性能对比 二、底层结构1.哈希概念2.哈希冲突3.哈希函数4.哈希冲突解决&#xff08;1&#xff09;闭散…...

浏览器访问 AWS ECS 上部署的 Docker 容器(监听 80 端口)

✅ 一、ECS 服务配置 Dockerfile 确保监听 80 端口 EXPOSE 80 CMD ["nginx", "-g", "daemon off;"]或 EXPOSE 80 CMD ["python3", "-m", "http.server", "80"]任务定义&#xff08;Task Definition&…...

日语AI面试高效通关秘籍:专业解读与青柚面试智能助攻

在如今就业市场竞争日益激烈的背景下&#xff0c;越来越多的求职者将目光投向了日本及中日双语岗位。但是&#xff0c;一场日语面试往往让许多人感到步履维艰。你是否也曾因为面试官抛出的“刁钻问题”而心生畏惧&#xff1f;面对生疏的日语交流环境&#xff0c;即便提前恶补了…...

day52 ResNet18 CBAM

在深度学习的旅程中&#xff0c;我们不断探索如何提升模型的性能。今天&#xff0c;我将分享我在 ResNet18 模型中插入 CBAM&#xff08;Convolutional Block Attention Module&#xff09;模块&#xff0c;并采用分阶段微调策略的实践过程。通过这个过程&#xff0c;我不仅提升…...

基于服务器使用 apt 安装、配置 Nginx

&#x1f9fe; 一、查看可安装的 Nginx 版本 首先&#xff0c;你可以运行以下命令查看可用版本&#xff1a; apt-cache madison nginx-core输出示例&#xff1a; nginx-core | 1.18.0-6ubuntu14.6 | http://archive.ubuntu.com/ubuntu focal-updates/main amd64 Packages ng…...

数据链路层的主要功能是什么

数据链路层&#xff08;OSI模型第2层&#xff09;的核心功能是在相邻网络节点&#xff08;如交换机、主机&#xff09;间提供可靠的数据帧传输服务&#xff0c;主要职责包括&#xff1a; &#x1f511; 核心功能详解&#xff1a; 帧封装与解封装 封装&#xff1a; 将网络层下发…...

Java 加密常用的各种算法及其选择

在数字化时代&#xff0c;数据安全至关重要&#xff0c;Java 作为广泛应用的编程语言&#xff0c;提供了丰富的加密算法来保障数据的保密性、完整性和真实性。了解这些常用加密算法及其适用场景&#xff0c;有助于开发者在不同的业务需求中做出正确的选择。​ 一、对称加密算法…...

深入解析C++中的extern关键字:跨文件共享变量与函数的终极指南

&#x1f680; C extern 关键字深度解析&#xff1a;跨文件编程的终极指南 &#x1f4c5; 更新时间&#xff1a;2025年6月5日 &#x1f3f7;️ 标签&#xff1a;C | extern关键字 | 多文件编程 | 链接与声明 | 现代C 文章目录 前言&#x1f525;一、extern 是什么&#xff1f;&…...

CMake控制VS2022项目文件分组

我们可以通过 CMake 控制源文件的组织结构,使它们在 VS 解决方案资源管理器中以“组”(Filter)的形式进行分类展示。 🎯 目标 通过 CMake 脚本将 .cpp、.h 等源文件分组显示在 Visual Studio 2022 的解决方案资源管理器中。 ✅ 支持的方法汇总(共4种) 方法描述是否推荐…...

#Uniapp篇:chrome调试unapp适配

chrome调试设备----使用Android模拟机开发调试移动端页面 Chrome://inspect/#devices MuMu模拟器Edge浏览器&#xff1a;Android原生APP嵌入的H5页面元素定位 chrome://inspect/#devices uniapp单位适配 根路径下 postcss.config.js 需要装这些插件 “postcss”: “^8.5.…...

短视频矩阵系统文案创作功能开发实践,定制化开发

在短视频行业迅猛发展的当下&#xff0c;企业和个人创作者为了扩大影响力、提升传播效果&#xff0c;纷纷采用短视频矩阵运营策略&#xff0c;同时管理多个平台、多个账号的内容发布。然而&#xff0c;频繁的文案创作需求让运营者疲于应对&#xff0c;如何高效产出高质量文案成…...