中文文本分类(pytorch 实现)
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms, datasets
import os, PIL, pathlib, warningswarnings.filterwarnings("ignore") # 忽略警告信息# win10系统
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
train.csv 链接:https://pan.baidu.com/s/1Vnyvo5T5eSuzb0VwTsznqA?pwd=fqok 提取码:fqok
import pandas as pd# 加载自定义中文数据集
train_data = pd.read_csv('D:/train.csv', sep='\t', header=None)
train_data.head()# 构建数据集迭代器
def coustom_data_iter(texts, labels):for x, y in zip(texts, labels):yield x, ytrain_iter = coustom_data_iter(train_data[0].values[:], train_data[1].values[:])
1.构建词典:
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
import jieba# 中文分词方法
tokenizer = jieba.lcutdef yield_tokens(data_iter):for text, in data_iter:yield tokenizer(text)vocab = build_vocab_from_iterator(yield_tokens(train_iter), specials=["<unk>"])
vocab.set_default_index(vocab["<unk>"])
调用vocab(词汇表)对一个中文句子进行索引转换,这个句子被分词后得到的词汇列表会被转换成它们在词汇表中的索引。
print(vocab(['我', '想', '看', '书', '和', '你', '一起', '看', '电影', '的', '新款', '视频']))
生成一个标签列表,用于查看在数据集中所有可能的标签类型。
label_name = list(set(train_data[1].values[:]))
print(label_name)
创建了两个lambda函数,一个用于将文本转换成词汇索引,另一个用于将标签文本转换成它们在label_name列表中的索引。
text_pipeline = lambda x: vocab(tokenizer(x))
label_pipeline = lambda x: label_name.index(x)print(text_pipeline('我想看新闻或者上网站看最新的游戏视频'))
print(label_pipeline('Video-Play'))
2.生成数据批次和迭代器
from torch.utils.data import DataLoaderdef collate_batch(batch):label_list, text_list, offsets = [], [], [0]for (_text, _label) in batch:# 标签列表label_list.append(label_pipeline(_label))# 文本列表processed_text = torch.tensor(text_pipeline(_text), dtype=torch.int64)text_list.append(processed_text)# 偏移量,即词汇的起始位置offsets.append(processed_text.size(0))label_list = torch.tensor(label_list, dtype=torch.int64)text_list = torch.cat(text_list)offsets = torch.tensor(offsets[:-1]).cumsum(dim=0) # 累计偏移量dim中维度元素的累计和return text_list.to(device), label_list.to(device), offsets.to(device)# 数据加载器,调用示例
dataloader = DataLoader(train_iter,batch_size=8,shuffle=False,collate_fn=collate_batch)
collate_batch函数用于处理数据加载器中的批次。它接收一个批次的数据,处理它,并返回适合模型训练的数据格式。
在这个函数内部,它遍历批次中的每个文本和标签对,将标签添加到label_list,将文本通过text_pipeline函数处理后转换为tensor,并添加到text_list。
offsets列表用于存储每个文本的长度,这对于后续的文本处理非常有用,尤其是当你需要知道每个文本在拼接的大tensor中的起始位置时。
text_list用torch.cat进行拼接,形成一个连续的tensor。
offsets列表的最后一个元素不包括,然后使用cumsum函数在第0维计算累积和,这为每个序列提供了一个累计的偏移量。
3.搭建模型与初始化
from torch import nnclass TextClassificationModel(nn.Module):def __init__(self, vocab_size, embed_dim, num_class):super(TextClassificationModel, self).__init__()self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=False)self.fc = nn.Linear(embed_dim, num_class)self.init_weights()def init_weights(self):initrange = 0.5self.embedding.weight.data.uniform_(-initrange, initrange)self.fc.weight.data.uniform_(-initrange, initrange)self.fc.bias.data.zero_()def forward(self, text, offsets):embedded = self.embedding(text, offsets)return self.fc(embedded)num_class = len(label_name) # 类别数,根据label_name的长度确定
vocab_size = len(vocab) # 词汇表的大小,根据vocab的长度确定
em_size = 64 # 嵌入向量的维度设置为64
model = TextClassificationModel(vocab_size, em_size, num_class).to(device) # 创建模型实例并移动到计算设备
4.模型训练及评估函数
train 和 evaluate分别用于训练和评估文本分类模型。
训练函数 train 的工作流程如下:
将模型设置为训练模式。
初始化总准确率、训练损失和总计数变量。
记录训练开始的时间。
遍历数据加载器,对每个批次:
进行预测。
清零优化器的梯度。
计算损失(使用一个损失函数,例如交叉熵)。
反向传播计算梯度。
通过梯度裁剪防止梯度爆炸。
执行一步优化器更新模型权重。
更新总准确率和总损失。
每隔一定间隔,打印训练进度和统计信息。
评估函数 evaluate 的工作流程如下:
将模型设置为评估模式。
初始化总准确率和总损失。
不计算梯度(为了节省内存和计算资源)。
遍历数据加载器,对每个批次:
进行预测。
计算损失。
更新总准确率和总损失。
返回整体的准确率和平均损失。
代码实现:
import timedef train(dataloader):model.train() # 切换到训练模式total_acc, train_loss, total_count = 0, 0, 0log_interval = 50start_time = time.time()for idx, (text, label, offsets) in enumerate(dataloader):predicted_label = model(text, offsets)optimizer.zero_grad() # 梯度归零loss = criterion(predicted_label, label) # 计算损失loss.backward() # 反向传播torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1) # 梯度裁剪optimizer.step() # 优化器更新权重# 记录acc和losstotal_acc += (predicted_label.argmax(1) == label).sum().item()train_loss += loss.item()total_count += label.size(0)if idx % log_interval == 0 and idx > 0:elapsed = time.time() - start_timeprint('| epoch {:3d} | {:5d}/{:5d} batches ''| accuracy {:8.3f} | loss {:8.5f}'.format(epoch, idx, len(dataloader),total_acc/total_count, train_loss/total_count))total_acc, train_loss, total_count = 0, 0, 0start_time = time.time()def evaluate(dataloader):model.eval() # 切换到评估模式total_acc, total_count = 0, 0with torch.no_grad():for idx, (text, label, offsets) in enumerate(dataloader):predicted_label = model(text, offsets)loss = criterion(predicted_label, label) # 计算losstotal_acc += (predicted_label.argmax(1) == label).sum().item()total_count += label.size(0)return total_acc/total_count, total_count
5.模型训练
设置训练的轮数、学习率和批次大小。
定义交叉熵损失函数、随机梯度下降优化器和学习率调度器。
将训练数据转换为一个map样式的数据集,并将其分成训练集和验证集。
创建训练和验证的数据加载器。
开始训练循环,每个epoch都会训练模型并在验证集上评估模型的准确率和损失。
如果验证准确率没有提高,则按计划降低学习率。
打印每个epoch结束时的统计信息,包括时间、准确率、损失和学习率。
from torch.utils.data.dataset import random_split
from torchtext.data.functional import to_map_style_dataset
# 参数设置
EPOCHS = 10 # epoch数量
LR = 5 # 学习速率
BATCH_SIZE = 64 # 训练的batch大小# 设置损失函数、优化器和调度器
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=LR)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.1)
total_accu = None# 准备数据集
train_iter = coustom_data_iter(train_data[0].values[:], train_data[1].values[:])
train_dataset = to_map_style_dataset(train_iter)split_train_, split_valid_ = random_split(train_dataset,[int(len(train_dataset)*0.8), int(len(train_dataset)*0.2)])train_dataloader = DataLoader(split_train_, batch_size=BATCH_SIZE,shuffle=True, collate_fn=collate_batch)valid_dataloader = DataLoader(split_valid_, batch_size=BATCH_SIZE,shuffle=True, collate_fn=collate_batch)# 训练循环
for epoch in range(1, EPOCHS + 1):epoch_start_time = time.time()train(train_dataloader)val_acc, val_loss = evaluate(valid_dataloader)# 更新学习率的策略lr = optimizer.state_dict()['param_groups'][0]['lr']if total_accu is not None and total_accu > val_acc:scheduler.step()else:total_accu = val_accprint('-' * 69)print('| end of epoch {:3d} | time: {:4.2f}s | ''valid accuracy {:4.3f} | valid loss {:4.3f} | lr {:4.6f}'.format(epoch, time.time() - epoch_start_time, val_acc, val_loss, lr))print('-' * 69)
运行结果:
| epoch 1 | 50/ 152 batches | accuracy 0.423 | loss 0.03079
| epoch 1 | 100/ 152 batches | accuracy 0.700 | loss 0.01912
| epoch 1 | 150/ 152 batches | accuracy 0.776 | loss 0.01347
---------------------------------------------------------------------
| end of epoch 1 | time: 1.53s | valid accuracy 0.777 | valid loss 2420.000 | lr 5.000000
| epoch 2 | 50/ 152 batches | accuracy 0.812 | loss 0.01056
| epoch 2 | 100/ 152 batches | accuracy 0.843 | loss 0.00871
| epoch 2 | 150/ 152 batches | accuracy 0.844 | loss 0.00846
---------------------------------------------------------------------
| end of epoch 2 | time: 1.45s | valid accuracy 0.842 | valid loss 2420.000 | lr 5.000000
| epoch 3 | 50/ 152 batches | accuracy 0.883 | loss 0.00653
| epoch 3 | 100/ 152 batches | accuracy 0.879 | loss 0.00634
| epoch 3 | 150/ 152 batches | accuracy 0.883 | loss 0.00627
---------------------------------------------------------------------
| end of epoch 3 | time: 1.44s | valid accuracy 0.865 | valid loss 2420.000 | lr 5.000000
| epoch 4 | 50/ 152 batches | accuracy 0.912 | loss 0.00498
| epoch 4 | 100/ 152 batches | accuracy 0.906 | loss 0.00495
| epoch 4 | 150/ 152 batches | accuracy 0.915 | loss 0.00461
---------------------------------------------------------------------
| end of epoch 4 | time: 1.50s | valid accuracy 0.876 | valid loss 2420.000 | lr 5.000000
| epoch 5 | 50/ 152 batches | accuracy 0.935 | loss 0.00386
| epoch 5 | 100/ 152 batches | accuracy 0.934 | loss 0.00390
| epoch 5 | 150/ 152 batches | accuracy 0.932 | loss 0.00362
---------------------------------------------------------------------
| end of epoch 5 | time: 1.59s | valid accuracy 0.881 | valid loss 2420.000 | lr 5.000000
| epoch 6 | 50/ 152 batches | accuracy 0.947 | loss 0.00313
| epoch 6 | 100/ 152 batches | accuracy 0.949 | loss 0.00307
| epoch 6 | 150/ 152 batches | accuracy 0.949 | loss 0.00286
---------------------------------------------------------------------
| end of epoch 6 | time: 1.68s | valid accuracy 0.891 | valid loss 2420.000 | lr 5.000000
| epoch 7 | 50/ 152 batches | accuracy 0.960 | loss 0.00243
| epoch 7 | 100/ 152 batches | accuracy 0.963 | loss 0.00224
| epoch 7 | 150/ 152 batches | accuracy 0.959 | loss 0.00252
---------------------------------------------------------------------
| end of epoch 7 | time: 1.53s | valid accuracy 0.892 | valid loss 2420.000 | lr 5.000000
| epoch 8 | 50/ 152 batches | accuracy 0.972 | loss 0.00186
| epoch 8 | 100/ 152 batches | accuracy 0.974 | loss 0.00184
| epoch 8 | 150/ 152 batches | accuracy 0.967 | loss 0.00201
---------------------------------------------------------------------
| end of epoch 8 | time: 1.43s | valid accuracy 0.895 | valid loss 2420.000 | lr 5.000000
| epoch 9 | 50/ 152 batches | accuracy 0.981 | loss 0.00138
| epoch 9 | 100/ 152 batches | accuracy 0.977 | loss 0.00165
| epoch 9 | 150/ 152 batches | accuracy 0.980 | loss 0.00147
---------------------------------------------------------------------
| end of epoch 9 | time: 1.48s | valid accuracy 0.900 | valid loss 2420.000 | lr 5.000000
| epoch 10 | 50/ 152 batches | accuracy 0.987 | loss 0.00117
| epoch 10 | 100/ 152 batches | accuracy 0.985 | loss 0.00121
| epoch 10 | 150/ 152 batches | accuracy 0.984 | loss 0.00121
---------------------------------------------------------------------
| end of epoch 10 | time: 1.45s | valid accuracy 0.902 | valid loss 2420.000 | lr 5.000000
---------------------------------------------------------------------
6.模型评估
test_acc, test_loss = evaluate(valid_dataloader)
print('模型的准确率: {:5.4f}'.format(test_acc))
7.模型测试
def predict(text, text_pipeline):with torch.no_grad():text = torch.tensor(text_pipeline(text))output = model(text, torch.tensor([0]))return output.argmax(1).item()# 示例文本字符串
# ex_text_str = "例句输入——这是一个待预测类别的示例句子"
ex_text_str = "这不仅影响到我们的方案是否可行13号的"model = model.to("cpu")print("该文本的类别是: %s" % label_name[predict(ex_text_str, text_pipeline)])
8.全部代码(部分修改):
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms, datasets
import os, PIL, pathlib, warningswarnings.filterwarnings("ignore") # 忽略警告信息# win10系统
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)import pandas as pd# 加载自定义中文数据集
train_data = pd.read_csv('D:/train.csv', sep='\t', header=None)
train_data.head()# 构建数据集迭代器
def custom_data_iter(texts, labels):for x, y in zip(texts, labels):yield x, ytrain_iter = custom_data_iter(train_data[0].values[:], train_data[1].values[:])from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
import jieba# 中文分词方法
tokenizer = jieba.lcutdef yield_tokens(data_iter):for text,_ in data_iter:yield tokenizer(text)vocab = build_vocab_from_iterator(yield_tokens(train_iter), specials=["<unk>"])
vocab.set_default_index(vocab["<unk>"])print(vocab(['我', '想', '看', '书', '和', '你', '一起', '看', '电影', '的', '新款', '视频']))label_name = list(set(train_data[1].values[:]))
print(label_name)text_pipeline = lambda x: vocab(tokenizer(x))
label_pipeline = lambda x: label_name.index(x)print(text_pipeline('我想看新闻或者上网站看最新的游戏视频'))
print(label_pipeline('Video-Play'))from torch.utils.data import DataLoaderdef collate_batch(batch):label_list, text_list, offsets = [], [], [0]for (_text, _label) in batch:# 标签列表label_list.append(label_pipeline(_label))# 文本列表processed_text = torch.tensor(text_pipeline(_text), dtype=torch.int64)text_list.append(processed_text)# 偏移量,即词汇的起始位置offsets.append(processed_text.size(0))label_list = torch.tensor(label_list, dtype=torch.int64)text_list = torch.cat(text_list)offsets = torch.tensor(offsets[:-1]).cumsum(dim=0) # 累计偏移量dim中维度元素的累计和return text_list.to(device), label_list.to(device), offsets.to(device)# 数据加载器,调用示例
dataloader = DataLoader(train_iter,batch_size=8,shuffle=False,collate_fn=collate_batch)from torch import nnclass TextClassificationModel(nn.Module):def __init__(self, vocab_size, embed_dim, num_class):super(TextClassificationModel, self).__init__()self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=False)self.fc = nn.Linear(embed_dim, num_class)self.init_weights()def init_weights(self):initrange = 0.5self.embedding.weight.data.uniform_(-initrange, initrange)self.fc.weight.data.uniform_(-initrange, initrange)self.fc.bias.data.zero_()def forward(self, text, offsets):embedded = self.embedding(text, offsets)return self.fc(embedded)
num_class = len(label_name)
vocab_size = len(vocab)
em_size = 64
model = TextClassificationModel(vocab_size, em_size, num_class).to(device)import timedef train(dataloader):model.train() # 切换到训练模式total_acc, train_loss, total_count = 0, 0, 0log_interval = 50start_time = time.time()for idx, (text, label, offsets) in enumerate(dataloader):predicted_label = model(text, offsets)optimizer.zero_grad() # 梯度归零loss = criterion(predicted_label, label) # 计算损失loss.backward() # 反向传播torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1) # 梯度裁剪optimizer.step() # 优化器更新权重# 记录acc和losstotal_acc += (predicted_label.argmax(1) == label).sum().item()train_loss += loss.item()total_count += label.size(0)if idx % log_interval == 0 and idx > 0:elapsed = time.time() - start_timeprint('| epoch {:3d} | {:5d}/{:5d} batches ''| accuracy {:8.3f} | loss {:8.5f}'.format(epoch, idx, len(dataloader),total_acc/total_count, train_loss/total_count))total_acc, train_loss, total_count = 0, 0, 0start_time = time.time()def evaluate(dataloader):model.eval() # 切换到评估模式total_acc, total_count = 0, 0with torch.no_grad():for idx, (text, label, offsets) in enumerate(dataloader):predicted_label = model(text, offsets)loss = criterion(predicted_label, label) # 计算losstotal_acc += (predicted_label.argmax(1) == label).sum().item()total_count += label.size(0)return total_acc/total_count, total_countfrom torch.utils.data.dataset import random_split
from torchtext.data.functional import to_map_style_dataset
# 参数设置
EPOCHS = 10 # epoch数量
LR = 5 # 学习速率
BATCH_SIZE = 64 # 训练的batch大小# 设置损失函数、优化器和调度器
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=LR)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.1)
total_accu = None# 准备数据集
train_iter = custom_data_iter(train_data[0].values[:], train_data[1].values[:])
train_dataset = to_map_style_dataset(train_iter)split_train_, split_valid_ = random_split(train_dataset,[int(len(train_dataset)*0.8), int(len(train_dataset)*0.2)])train_dataloader = DataLoader(split_train_, batch_size=BATCH_SIZE,shuffle=True, collate_fn=collate_batch)valid_dataloader = DataLoader(split_valid_, batch_size=BATCH_SIZE,shuffle=True, collate_fn=collate_batch)# 训练循环
for epoch in range(1, EPOCHS + 1):epoch_start_time = time.time()train(train_dataloader)val_acc, val_loss = evaluate(valid_dataloader)# 更新学习率的策略lr = optimizer.state_dict()['param_groups'][0]['lr']if total_accu is not None and total_accu > val_acc:scheduler.step()else:total_accu = val_accprint('-' * 69)print('| end of epoch {:3d} | time: {:4.2f}s | ''valid accuracy {:4.3f} | valid loss {:4.3f} | lr {:4.6f}'.format(epoch, time.time() - epoch_start_time, val_acc, val_loss, lr))print('-' * 69)test_acc, test_loss = evaluate(valid_dataloader)
print('模型的准确率: {:5.4f}'.format(test_acc))def predict(text, text_pipeline):with torch.no_grad():text = torch.tensor(text_pipeline(text))output = model(text, torch.tensor([0]))return output.argmax(1).item()# 示例文本字符串
# ex_text_str = "例句输入——这是一个待预测类别的示例句子"
ex_text_str = "这不仅影响到我们的方案是否可行13号的"model = model.to("cpu")print("该文本的类别是: %s" % label_name[predict(ex_text_str, text_pipeline)])
9.代码改进及优化
9.1优化器: 尝试不同的优化算法,如Adam、RMSprop替换原来的SGD优化器部分
9.1.1使用Adam优化器:
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms, datasets
import os, PIL, pathlib, warningswarnings.filterwarnings("ignore") # 忽略警告信息# win10系统
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)import pandas as pd# 加载自定义中文数据集
train_data = pd.read_csv('D:/train.csv', sep='\t', header=None)
train_data.head()# 构建数据集迭代器
def custom_data_iter(texts, labels):for x, y in zip(texts, labels):yield x, ytrain_iter = custom_data_iter(train_data[0].values[:], train_data[1].values[:])from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
import jieba# 中文分词方法
tokenizer = jieba.lcutdef yield_tokens(data_iter):for text,_ in data_iter:yield tokenizer(text)vocab = build_vocab_from_iterator(yield_tokens(train_iter), specials=["<unk>"])
vocab.set_default_index(vocab["<unk>"])print(vocab(['我', '想', '看', '书', '和', '你', '一起', '看', '电影', '的', '新款', '视频']))label_name = list(set(train_data[1].values[:]))
print(label_name)text_pipeline = lambda x: vocab(tokenizer(x))
label_pipeline = lambda x: label_name.index(x)print(text_pipeline('我想看新闻或者上网站看最新的游戏视频'))
print(label_pipeline('Video-Play'))from torch.utils.data import DataLoaderdef collate_batch(batch):label_list, text_list, offsets = [], [], [0]for (_text, _label) in batch:# 标签列表label_list.append(label_pipeline(_label))# 文本列表processed_text = torch.tensor(text_pipeline(_text), dtype=torch.int64)text_list.append(processed_text)# 偏移量,即词汇的起始位置offsets.append(processed_text.size(0))label_list = torch.tensor(label_list, dtype=torch.int64)text_list = torch.cat(text_list)offsets = torch.tensor(offsets[:-1]).cumsum(dim=0) # 累计偏移量dim中维度元素的累计和return text_list.to(device), label_list.to(device), offsets.to(device)# 数据加载器,调用示例
dataloader = DataLoader(train_iter,batch_size=8,shuffle=False,collate_fn=collate_batch)from torch import nnclass TextClassificationModel(nn.Module):def __init__(self, vocab_size, embed_dim, num_class):super(TextClassificationModel, self).__init__()self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=False)self.fc = nn.Linear(embed_dim, num_class)self.init_weights()def init_weights(self):initrange = 0.5self.embedding.weight.data.uniform_(-initrange, initrange)self.fc.weight.data.uniform_(-initrange, initrange)self.fc.bias.data.zero_()def forward(self, text, offsets):embedded = self.embedding(text, offsets)return self.fc(embedded)
num_class = len(label_name)
vocab_size = len(vocab)
em_size = 64
model = TextClassificationModel(vocab_size, em_size, num_class).to(device)import timedef train(dataloader):model.train() # 切换到训练模式total_acc, train_loss, total_count = 0, 0, 0log_interval = 50start_time = time.time()for idx, (text, label, offsets) in enumerate(dataloader):predicted_label = model(text, offsets)optimizer.zero_grad() # 梯度归零loss = criterion(predicted_label, label) # 计算损失loss.backward() # 反向传播torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1) # 梯度裁剪optimizer.step() # 优化器更新权重# 记录acc和losstotal_acc += (predicted_label.argmax(1) == label).sum().item()train_loss += loss.item()total_count += label.size(0)if idx % log_interval == 0 and idx > 0:elapsed = time.time() - start_timeprint('| epoch {:3d} | {:5d}/{:5d} batches ''| accuracy {:8.3f} | loss {:8.5f}'.format(epoch, idx, len(dataloader),total_acc/total_count, train_loss/total_count))total_acc, train_loss, total_count = 0, 0, 0start_time = time.time()def evaluate(dataloader):model.eval() # 切换到评估模式total_acc, total_count = 0, 0with torch.no_grad():for idx, (text, label, offsets) in enumerate(dataloader):predicted_label = model(text, offsets)loss = criterion(predicted_label, label) # 计算losstotal_acc += (predicted_label.argmax(1) == label).sum().item()total_count += label.size(0)return total_acc/total_count, total_countfrom torch.utils.data.dataset import random_split
from torchtext.data.functional import to_map_style_dataset
# 参数设置
EPOCHS = 10 # epoch数量
LR = 5 # 学习速率
BATCH_SIZE = 64 # 训练的batch大小# 设置损失函数、优化器和调度器
criterion = torch.nn.CrossEntropyLoss()
#optimizer = torch.optim.SGD(model.parameters(), lr=LR)
optimizer = torch.optim.Adam(model.parameters(), lr=LR)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.1)
total_accu = None# 准备数据集
train_iter = custom_data_iter(train_data[0].values[:], train_data[1].values[:])
train_dataset = to_map_style_dataset(train_iter)split_train_, split_valid_ = random_split(train_dataset,[int(len(train_dataset)*0.8), int(len(train_dataset)*0.2)])train_dataloader = DataLoader(split_train_, batch_size=BATCH_SIZE,shuffle=True, collate_fn=collate_batch)valid_dataloader = DataLoader(split_valid_, batch_size=BATCH_SIZE,shuffle=True, collate_fn=collate_batch)# 训练循环
for epoch in range(1, EPOCHS + 1):epoch_start_time = time.time()train(train_dataloader)val_acc, val_loss = evaluate(valid_dataloader)# 更新学习率的策略lr = optimizer.state_dict()['param_groups'][0]['lr']if total_accu is not None and total_accu > val_acc:scheduler.step()else:total_accu = val_accprint('-' * 69)print('| end of epoch {:3d} | time: {:4.2f}s | ''valid accuracy {:4.3f} | valid loss {:4.3f} | lr {:4.6f}'.format(epoch, time.time() - epoch_start_time, val_acc, val_loss, lr))print('-' * 69)test_acc, test_loss = evaluate(valid_dataloader)
print('模型的准确率: {:5.4f}'.format(test_acc))def predict(text, text_pipeline):with torch.no_grad():text = torch.tensor(text_pipeline(text))output = model(text, torch.tensor([0]))return output.argmax(1).item()# 示例文本字符串
# ex_text_str = "例句输入——这是一个待预测类别的示例句子"
ex_text_str = "这不仅影响到我们的方案是否可行13号的"model = model.to("cpu")print("该文本的类别是: %s" % label_name[predict(ex_text_str, text_pipeline)])
相关文章:
中文文本分类(pytorch 实现)
import torch import torch.nn as nn import torchvision from torchvision import transforms, datasets import os, PIL, pathlib, warningswarnings.filterwarnings("ignore") # 忽略警告信息# win10系统 device torch.device("cuda" if torch.cuda.i…...
【每日前端面经】2023-02-27
题目来源: 牛客 CSS盒模型 CSS中的盒子包括margin|border|padding|content四个部分,对于标准盒子模型(content-box)的widthcontent,但是对于IE盒子模型(border-box)的widthcontentborder2padding2 CSS选…...
springboot + easyRules 搭建规则引擎服务
依赖 <dependency><groupId>org.jeasy</groupId><artifactId>easy-rules-core</artifactId><version>4.0.0</version></dependency><dependency><groupId>org.jeasy</groupId><artifactId>easy-rules…...
Mac电脑配置环境变量
1.打开配置文件bash_profile open -e .bash_profile 2.如果没有创建过.bash_profile,则先需要创建 touch .bash_profile 3.输入你要配置的环境变量 #Setting PATH for Android ADB Tools export ANDROID_HOME/Users/xxx/android export PATH${PATH}:${ANDROID_HOME}…...
Windows系统x86机器安装(麒麟、统信)ARM系统详细教程
本次介绍在window系统x86机器上安装国产系统 arm 系统的详细教程。 注:ubuntu 的arm系统安装是一样的流程。 1.安装环境准备。 首先,你得有台电脑,配置别太差,至少4核8G内存,安装window10或者11都行(为啥…...
消息中间件篇之RabbitMQ-高可用机制
一、怎么保证高可用性 在生产环境下,使用集群来保证高可用性,一般我们采用普通集群、镜像集群、仲裁队列。 二、普通集群 普通集群,或者叫标准集群(classic cluster),具备下列特征: 1. 会在集…...
express+mysql+vue,从零搭建一个商城管理系统5--用户注册
提示:学习express,搭建管理系统 文章目录 前言一、新建user表二、安装bcryptjs、MD5、body-parser三、修改config/db.js四、新建config/bcrypt.js五、新建models文件夹和models/user.js五、index.js引入使用body-parser六、修改routes/user.js七、启动项…...
canvas水波纹效果,jquery鼠标水波纹插件
canvas水波纹效果,jquery鼠标水波纹插件 效果展示 jQuery水波纹效果,canvas水波纹插件 HTML代码片段 <div class"scroll04wrap"><h3>发展历程</h3><div class"scroll04"><p>不要回头,一…...
Zookeeper客户端命令、JAVA API、监听原理、写数据原理以及案例
1. Zookeeper节点信息 指定服务端,启动客户端命令: bin/zkCli.sh -server 服务端主机名:端口号 1)ls / 查看根节点下面的子节点 ls -s / 查看根节点下面的子节点以及根节点详细信息 其中,cZxid是创建节点的事务id,…...
[嵌入式系统-34]:RT-Thread -19- 新手指南:RT-Thread标准版系统架构
目录 一、RT-Thread 简介 二、RT-Thread 概述 三、许可协议 四、RT-Thread 的架构 4.1 内核层: 4.2 组件与服务层: 4.3 RT-Thread 软件包: 一、RT-Thread 简介 作为一名 RTOS 的初学者,也许你对 RT-Thread 还比较陌生。然…...
postman访问k8s api
第一种方式: kubectl -n kubesphere-system get sa kubesphere -oyaml apiVersion: v1 kind: ServiceAccount metadata:annotations:meta.helm.sh/release-name: ks-coremeta.helm.sh/release-namespace: kubesphere-systemcreationTimestamp: "2023-07-24T07…...
UE4c++ ConvertActorsToStaticMesh
UE4c ConvertActorsToStaticMesh ConvertActorsToStaticMesh UE4c ConvertActorsToStaticMesh创建Edior模块(最好是放Editor模块毕竟是编辑器代码)创建UBlueprintFunctionLibraryUTestFunctionLibrary.hUTestFunctionLibrary.cpp:.Build.cs 目标:为了大量…...
Qt中tableView控件的使用
tableView使用注意事项 tableView在使用时,从工具栏拖动到底层页面后,右键进行选择如下图所示: 此处需要注意的是,需要去修改属性,从UI上修改属性如下所示: 也可以通过代码修改属性: //将其设…...
【医学影像】LIDC-IDRI数据集的无痛制作
LIDC-IDRI数据集制作 0.下载0.0 链接汇总0.1 步骤 1.合成CT图reference 0.下载 0.0 链接汇总 LIDC-IDRI官方网址:https://www.cancerimagingarchive.net/nbia-search/?CollectionCriteriaLIDC-IDRINBIA Data Retriever 下载链接:https://wiki.canceri…...
MacOS开发环境搭建详解
搭建MacOS开发环境需要准备相应的软硬件,并遵循一系列步骤。以下是详细的步骤: 软硬件准备: MacOS电脑:确保你的电脑运行的是MacOS操作系统。Xcode软件:打开AppStore,搜索并安装Xcode。安装过程可能较长&…...
全量知识系统问题及SmartChat给出的答复 之2
Q6. 根据DDD的思想( 也就是借助 DDD的某个或某些实现),是否能按照这个想法给出程序设计和代码结构? 当使用领域驱动设计(DDD)的思想来设计程序和代码结构时,可以根据领域模型、领域服务、值对象、实体等概念来进行设计…...
嵌入式驱动学习第一周——vim的使用
前言 本篇博客学习使用vim,vim作为linux下的编辑器,学linux肯定是绕不开vim的,因为不确定对方环境中是否安装了编译器,但一定会有vim。 对于基本的使用只需要会打开文件,保存文件,编辑文件即可。 嵌入式驱动…...
loop_list单向循环列表
#include "loop_list.h" //创建单向循环链表 loop_p create_head() { loop_p L(loop_p)malloc(sizeof(loop_list)); if(LNULL) { printf("create fail\n"); return NULL; } L->len 0; L->nextL; retur…...
Python爬虫实战第二例【二】
零.前言: 本文章借鉴:Python爬虫实战(五):根据关键字爬取某度图片批量下载到本地(附上完整源码)_python爬虫下载图片-CSDN博客 大佬的文章里面有API的获取,在这里我就不赘述了。 一…...
Eclipse是如何创建web project项目的?
前面几篇描述先后描述了tomcat的目录结构和访问机制,以及Eclipse的项目类型和怎么调用jar包,还有java的main函数等,这些是一些基础问题,基础高清出来才更容易搞清楚后面要说的东西,也就是需求带动学习,后面…...
Python:操作 Excel 折叠
💖亲爱的技术爱好者们,热烈欢迎来到 Kant2048 的博客!我是 Thomas Kant,很开心能在CSDN上与你们相遇~💖 本博客的精华专栏: 【自动化测试】 【测试经验】 【人工智能】 【Python】 Python 操作 Excel 系列 读取单元格数据按行写入设置行高和列宽自动调整行高和列宽水平…...
页面渲染流程与性能优化
页面渲染流程与性能优化详解(完整版) 一、现代浏览器渲染流程(详细说明) 1. 构建DOM树 浏览器接收到HTML文档后,会逐步解析并构建DOM(Document Object Model)树。具体过程如下: (…...
MySQL中【正则表达式】用法
MySQL 中正则表达式通过 REGEXP 或 RLIKE 操作符实现(两者等价),用于在 WHERE 子句中进行复杂的字符串模式匹配。以下是核心用法和示例: 一、基础语法 SELECT column_name FROM table_name WHERE column_name REGEXP pattern; …...
智能仓储的未来:自动化、AI与数据分析如何重塑物流中心
当仓库学会“思考”,物流的终极形态正在诞生 想象这样的场景: 凌晨3点,某物流中心灯火通明却空无一人。AGV机器人集群根据实时订单动态规划路径;AI视觉系统在0.1秒内扫描包裹信息;数字孪生平台正模拟次日峰值流量压力…...
技术栈RabbitMq的介绍和使用
目录 1. 什么是消息队列?2. 消息队列的优点3. RabbitMQ 消息队列概述4. RabbitMQ 安装5. Exchange 四种类型5.1 direct 精准匹配5.2 fanout 广播5.3 topic 正则匹配 6. RabbitMQ 队列模式6.1 简单队列模式6.2 工作队列模式6.3 发布/订阅模式6.4 路由模式6.5 主题模式…...
C++课设:简易日历程序(支持传统节假日 + 二十四节气 + 个人纪念日管理)
名人说:路漫漫其修远兮,吾将上下而求索。—— 屈原《离骚》 创作者:Code_流苏(CSDN)(一个喜欢古诗词和编程的Coder😊) 专栏介绍:《编程项目实战》 目录 一、为什么要开发一个日历程序?1. 深入理解时间算法2. 练习面向对象设计3. 学习数据结构应用二、核心算法深度解析…...
解读《网络安全法》最新修订,把握网络安全新趋势
《网络安全法》自2017年施行以来,在维护网络空间安全方面发挥了重要作用。但随着网络环境的日益复杂,网络攻击、数据泄露等事件频发,现行法律已难以完全适应新的风险挑战。 2025年3月28日,国家网信办会同相关部门起草了《网络安全…...
关于easyexcel动态下拉选问题处理
前些日子突然碰到一个问题,说是客户的导入文件模版想支持部分导入内容的下拉选,于是我就找了easyexcel官网寻找解决方案,并没有找到合适的方案,没办法只能自己动手并分享出来,针对Java生成Excel下拉菜单时因选项过多导…...
在 Spring Boot 中使用 JSP
jsp? 好多年没用了。重新整一下 还费了点时间,记录一下。 项目结构: pom: <?xml version"1.0" encoding"UTF-8"?> <project xmlns"http://maven.apache.org/POM/4.0.0" xmlns:xsi"http://ww…...
前端高频面试题2:浏览器/计算机网络
本专栏相关链接 前端高频面试题1:HTML/CSS 前端高频面试题2:浏览器/计算机网络 前端高频面试题3:JavaScript 1.什么是强缓存、协商缓存? 强缓存: 当浏览器请求资源时,首先检查本地缓存是否命中。如果命…...
