当前位置: 首页 > article >正文

从代码学习深度学习 - LSTM PyTorch版

文章目录

  • 前言
  • 一、数据加载与预处理
    • 1.1 代码实现
    • 1.2 功能解析
  • 二、LSTM介绍
    • 2.1 LSTM原理
    • 2.2 模型定义
      • 代码解析
  • 三、训练与预测
    • 3.1 训练逻辑
      • 代码解析
    • 3.2 可视化工具
      • 功能解析
      • 功能结果
  • 总结


前言

深度学习中的循环神经网络(RNN)及其变种长短期记忆网络(LSTM)在处理序列数据(如文本、时间序列等)方面表现出色。本篇博客将通过一个完整的PyTorch实现,带你从零开始学习如何使用LSTM进行文本生成任务。我们将基于H.G. Wells的《时间机器》数据集,逐步展示数据预处理、模型定义、训练与预测的全过程。通过代码和文字的结合,帮助你深入理解LSTM的实现细节及其在自然语言处理中的应用。

本文的代码分为四个主要部分:

  1. 数据加载与预处理(utils_for_data.py
  2. LSTM模型定义(Jupyter Notebook中的模型部分)
  3. 训练与预测逻辑(utils_for_train.py
  4. 可视化工具(utils_for_huitu.py

以下是详细的实现与解析。


一、数据加载与预处理

首先,我们需要加载《时间机器》数据集并进行预处理。以下是utils_for_data.py中的完整代码及其功能说明。

1.1 代码实现

import random
import re
import torch
from collections import Counterdef read_time_machine():"""将时间机器数据集加载到文本行的列表中"""with open('timemachine.txt', 'r') as f:lines = f.readlines()return [re.sub('[^A-Za-z]+', ' ', line).strip().lower() for line in lines]def tokenize(lines, token='word'):"""将文本行拆分为单词或字符词元"""if token == 'word':return [line.split() for line in lines]elif token == 'char':return [list(line) for line in lines]else:print(f'错误:未知词元类型:{token}')def count_corpus(tokens):"""统计词元的频率"""if not tokens:return Counter()if isinstance(tokens[0], list):flattened_tokens = [token for sublist in tokens for token in sublist]else:flattened_tokens = tokensreturn Counter(flattened_tokens)class Vocab:"""文本词表类,用于管理词元及其索引的映射关系"""def __init__(self, tokens=None, min_freq=0, reserved_tokens=None):self.tokens = tokens if tokens is not None else []self.reserved_tokens = reserved_tokens if reserved_tokens is not None else []counter = self._count_corpus(self.tokens)self._token_freqs = sorted(counter.items(), key=lambda x: x[1], reverse=True)self.idx_to_token = ['<unk>'] + self.reserved_tokensself.token_to_idx = {token: idx for idx, token in enumerate(self.idx_to_token)}for token, freq in self._token_freqs:if freq < min_freq:breakif token not in self.token_to_idx:self.idx_to_token.append(token)self.token_to_idx[token] = len(self.idx_to_token) - 1@staticmethoddef _count_corpus(tokens):if not tokens:return Counter()if isinstance(tokens[0], list):tokens = [token for sublist in tokens for token in sublist]return Counter(tokens)def __len__(self):return len(self.idx_to_token)def __getitem__(self, tokens):if not isinstance(tokens, (list, tuple)):return self.token_to_idx.get(tokens, self.unk)return [self[token] for token in tokens]def to_tokens(self, indices):if not isinstance(indices, (list, tuple)):return self.idx_to_token[indices]return [self.idx_to_token[index] for index in indices]@propertydef unk(self):return 0@propertydef token_freqs(self):return self._token_freqsdef load_corpus_time_machine(max_tokens=-1):lines = read_time_machine()tokens = tokenize(lines, 'char')vocab = Vocab(tokens)corpus = [vocab[token] for line in tokens for token in line]if max_tokens > 0:corpus = corpus[:max_tokens]return corpus, vocabdef seq_data_iter_random(corpus, batch_size, num_steps):offset = random.randint(0, num_steps - 1)corpus = corpus[offset:]num_subseqs = (len(corpus) - 1) // num_stepsinitial_indices = list(range(0, num_subseqs * num_steps, num_steps))random.shuffle(initial_indices)def data(pos):return corpus[pos:pos + num_steps]num_batches = num_subseqs // batch_sizefor i in range(0, batch_size * num_batches, batch_size):initial_indices_per_batch = initial_indices[i:i + batch_size]X = [data(j) for j in initial_indices_per_batch]Y = [data(j + 1) for j in initial_indices_per_batch]yield torch.tensor(X), torch.tensor(Y)def seq_data_iter_sequential(corpus, batch_size, num_steps):offset = random.randint(0, num_steps)num_tokens = ((len(corpus) - offset - 1) // batch_size) *

相关文章:

从代码学习深度学习 - LSTM PyTorch版

文章目录 前言一、数据加载与预处理1.1 代码实现1.2 功能解析二、LSTM介绍2.1 LSTM原理2.2 模型定义代码解析三、训练与预测3.1 训练逻辑代码解析3.2 可视化工具功能解析功能结果总结前言 深度学习中的循环神经网络(RNN)及其变种长短期记忆网络(LSTM)在处理序列数据(如文…...

大数据技术发展与应用趋势分析

大数据技术发展与应用趋势分析 文章目录 大数据技术发展与应用趋势分析1. 大数据概述2 大数据技术架构2.1 数据采集层2.2 数据存储层2.3 数据处理层2.4 数据分析层 3 大数据发展趋势3.1 AI驱动的分析与自动化3.2 隐私保护分析技术3.3 混合云架构的普及3.4 数据网格架构3.5 量子…...

与Linux操作系统相关的引导和服务

目录 一.Linux操作系统引导过程 1.1引导过程总览 1.2系统初始化进程 1.2.1init进程 1.2.2sysmted 1.3systemd单元类型 二.排除启动类故障 2.1MBR扇区故障 2.1.1故障原因 2.1.2故障现象 2.1.3解决办法 2.1.4模拟修复MBR扇区故障 1)添加新的硬盘 2&#xff09;进行…...

STM32单片机入门学习——第16节: [6-4] PWM驱动LED呼吸灯PWM驱动舵机PWM驱动直流电机

写这个文章是用来学习的,记录一下我的学习过程。希望我能一直坚持下去,我只是一个小白,只是想好好学习,我知道这会很难&#xff0c;但我还是想去做&#xff01; 本文写于&#xff1a;2025.04.05 STM32开发板学习——第16节: [6-4] PWM驱动LED呼吸灯&PWM驱动舵机&PWM驱…...

基础框架系列分享:一个通用的Excel报表生成管理框架

由于我们系统经常要生成大量的Excel报表&#xff08;Word&#xff0c;PDF报表也有&#xff0c;另行分享&#xff09;&#xff0c;最初始他们的方案是&#xff0c;设计一个表&#xff0c;和Excel完全对应&#xff0c;然后读表&#xff0c;把数据填进去&#xff0c;这显然是非常不…...

Ansible(4)—— Playbook

目录 一、Ansible Playbook : 1、Play &#xff1a; 2、Playbook&#xff1a; 二、Ansible Playbook 格式&#xff1a; 1、空格&#xff1a; 2、破折号&#xff08; - &#xff09;&#xff1a; 3、Play 格式&#xff1a; 三、查找用于任务的模块&#xff1a; 1、模块…...

自学-C语言-基础-数组、函数、指针、结构体和共同体、文件

这里写自定义目录标题 代码环境&#xff1a;&#xff1f;问题思考&#xff1a;一、数组二、函数三、指针四、结构体和共同体五、文件问题答案&#xff1a; 代码环境&#xff1a; Dev C &#xff1f;问题思考&#xff1a; 把上门的字母与下面相同的字母相连&#xff0c;线不能…...

Bash 花括号扩展 {start..end} 进阶使用指南——字典生成

Bash 的花括号扩展&#xff08;brace expansion&#xff09;{start..end} 是一个强大而灵活的语法特性&#xff0c;用于生成特定序列或组合。它在脚本编写、爆破字典生成、文件批量操作以及模式匹配中有着广泛的应用。本文将从基础用法到高级技巧&#xff0c;带你全面掌握这一功…...

AGI大模型(10):prompt逆向-巧借prompt

1 提示词逆向 明确逆向提示词⼯程概念 我们可以给ChatGPT提供⼀个简洁的提示词,让它能够更准确地理解我们所讨论的“逆向提示词⼯程”是什么意思,并通过这个思考过程,帮它将相关知识集中起来,进⽽构建⼀个专业的知识领域 提示词:请你举⼀个简单的例⼦,解释⼀下逆向pro…...

蓝桥云客--团队赛

2.团队赛【算法赛】 - 蓝桥云课 问题描述 蓝桥杯最近推出了一项团队赛模式&#xff0c;要求三人组队参赛&#xff0c;并规定其中一人必须担任队长。队长的资格很简单&#xff1a;其程序设计能力值必须严格大于其他两名队友程序设计能力值的总和。 小蓝、小桥和小杯正在考虑报名…...

C-S模式之实现一对一聊天

天天开心&#xff01;&#xff01;&#xff01; 文章目录 一、如何实现一对一聊天&#xff1f;1. 服务器设计2. 客户端设计3. 服务端代码实现4. 客户端代码实现5. 实现说明6.实验结果 二、改进常见的服务器高并发方案1. 多线程/多进程模型2. I/O多路复用3. 异步I/O&#xff08;…...

[Deep-ML]Transpose of a Matrix(矩阵的转置)

Transpose of a Matrix&#xff08;矩阵的转置&#xff09; 题目链接&#xff1a; Transpose of a Matrix&#xff08;矩阵的转置&#xff09;https://www.deep-ml.com/problems/2 题目描述&#xff1a; 难度&#xff1a; easy&#xff08;简单&#xff09;。 分类&#…...

Java的Selenium的特殊元素操作与定位之select下拉框

如果页面元素是一个下拉框&#xff0c;我们可以将此web元素封装为Select对象 Select selectnew Select(WebElement element); Select对象常用api select.getOptions();//获取所有选项select.selectBylndex(index);//根据索引选中对应的元素select.selectByValue(value);//选…...

前端精度计算:Decimal.js 基本用法与详解

一、Decimal.js 简介 decimal.js 是一个用于任意精度算术运算的 JavaScript 库&#xff0c;它可以完美解决浮点数计算中的精度丢失问题。 官方API文档&#xff1a;Decimal.js 特性&#xff1a; 任意精度计算&#xff1a;支持大数、小数的高精度运算。 链式调用&#xff1a;…...

智慧节能双突破 强力巨彩谷亚VK系列刷新LED屏使用体验

当前全球节能减排趋势明显&#xff0c;LED节能屏作为显示技术的佼佼者&#xff0c;正逐渐成为市场的新宠。强力巨彩谷亚万境VK系列节能智慧屏凭借三重技术保障、四大智能设计以及大师臻彩画质&#xff0c;在实现节能效果的同时&#xff0c;更在智慧显示领域树立新的标杆。   …...

html 给文本两端加虚线自适应

效果图&#xff1a; <div class"separator">文本 </div>.separator {width: 40%;border-style: dashed;display: flex;align-items: center;color: #e2e2e2;font-size: 14px;line-height: 20px;border-color: #e2e2e2;border-width: 0; }.separator::bef…...

C#:is关键字

目录 is 关键字的核心是什么&#xff1f; 1. 什么是 is 关键字&#xff0c;为什么要用它&#xff1f; 2. 如何使用 is 关键字&#xff1f; 3. is 的作用和场景 4. is 与 as 的区别 5. 模式匹配的扩展&#xff08;C# 8.0&#xff09; 6. 常见陷阱和注意事项 总结&#x…...

leetcode4.寻找两个正序数组中的中位数

思路源于 LeetCode004-两个有序数组的中位数-最优算法代码讲解 基本思路是将两个数组看成一个数组&#xff0c;然后划分为两个部分&#xff0c;若为奇数左边部分个数多1&#xff0c;若为偶数左边部分等于右边部分个数。i表示数组1划分位置&#xff08;i为4是索引4也表示i的左半…...

0101安装matplotlib_numpy_pandas-报错-python

文章目录 1 前言2 报错报错1&#xff1a;ModuleNotFoundError: No module named distutils报错2&#xff1a;ERROR:root:code for hash blake2b was not found.报错3&#xff1a;**ModuleNotFoundError: No module named _tkinter**报错4&#xff1a;UserWarning: Glyph 39044 …...

Qt之QHostInfo

简介 QHostInfo表示主机信息&#xff0c;即主机名称 常用接口 static QHostInfo fromName(const QString &name); QString hostName() const; QList<QHostAddress> addresses() const;结构 #mermaid-svg-HTJ95sEk8JwO4uCy {font-family:"trebuchet ms",…...

OSCP - Proving Grounds- SoSimple

主要知识点 wordpress 插件RCE漏洞sudo -l shell劫持 具体步骤 依旧是nmap 起手&#xff0c;只发现了22和80端口&#xff0c;但80端口只能看到一张图 Nmap scan report for 192.168.214.78 Host is up (0.46s latency). Not shown: 65533 closed tcp ports (reset) PORT …...

JVM虚拟机篇(五):深入理解Java类加载器与类加载机制

深入理解Java类加载器与类加载机制 深入理解Java类加载器与类加载机制一、引言二、类加载器2.1 类加载器的定义2.2 类加载器的分类2.2.1 启动类加载器&#xff08;Bootstrap ClassLoader&#xff09;2.2.2 扩展类加载器&#xff08;Extension ClassLoader&#xff09;2.2.3 应用…...

Golang的Goroutine(协程)与runtime

目录 Runtime 包概述 Runtime 包常用函数 1. GOMAXPROCS 2. Caller 和 Callers 3. BlockProfile 和 Stack 理解Golang的Goroutine Goroutine的基本概念 特点&#xff1a; Goroutine的创建与启动 示例代码 解释 Goroutine的调度 Gosched的作用 示例代码 输出 解…...

C语言求3到100之间的素数

一、代码展示 二、运行结果 三、感悟思考 注意: 这个题思路他是一个试除法的一个思路 先进入一个for循环 遍历3到100之间的数字 第二个for循环则是 判断他不是素数 那么就直接退出 这里用break 是素数就打印出来 在第一个for循环内 第二个for循环外...

【2025】物联网发展趋势介绍

目录 物联网四层架构感知识别层网络构建层管理服务层——**边缘存储**边缘计算关键技术&#xff1a;综合应用层——信息应用 物联网四层架构 综合应用层&#xff1a;信息应用 利用获取的信息和知识&#xff0c;支持各类应用系统的运转 管理服务层&#xff1a;信息处理 对数据进…...

如何查看 MySQL 的磁盘空间使用情况:从表级到数据库级的分析

在日常数据库管理中&#xff0c;了解每张表和每个数据库占用了多少磁盘空间是非常关键的。这不仅有助于我们监控数据增长&#xff0c;还能为性能优化提供依据。 Google Gemini中国版调用Google Gemini API&#xff0c;中国大陆优化&#xff0c;完全免费&#xff01;https://ge…...

AI平台初步规划实现和想法

要实现一个类似Coze的工作流搭建引擎&#xff0c;可以结合SmartEngine作为后端工作流引擎&#xff0c;ReactFlow作为前端流程图渲染工具&#xff0c;以及Ant Design作为UI组件库。以下是实现的步骤和关键点&#xff1a; ### 1. 后端工作流引擎&#xff08;SmartEngine&#xf…...

ARXML文件解析-2

目录 1 摘要2 常见ARXML文件注意事项以及常见问题2.1 注意事项2.2 常见问题2.3 答疑 3 ARXML解读/编辑指南3.1 解读ARXML文件的步骤3.2 编辑ARXML文件的方法3.3 验证与调试 4 总结 1 摘要 本文主要对ARXML文件的注意事项、常见问题以及解读与编辑进行详细介绍。 上文回顾&…...

汇编学习之《移位指令》

这章节学习前需要回顾之前的标志寄存器的内容&#xff1a; 汇编学习之《标志寄存器》 算数移位指令 SAL (Shift Arithmetic Left)算数移位指令 : 左移一次&#xff0c;最低位用0补位&#xff0c;最高位放入EFL标志寄存器的CF位&#xff08;进位标志&#xff09; OllyDbg查看…...

Nature Communications上交、西湖大学、复旦大学研发面向机器人多模式运动的去电子化刚弹耦合高频自振荡驱动单元

近年来&#xff0c;轻型仿生机器人因其卓越的运动灵活性与环境适应性受到国际机器人领域的广泛关注。然而&#xff0c;现有气动驱动器普遍受限于低模量粘弹性材料的回弹滞后效应与能量耗散特性&#xff0c;加之其"非刚即柔"的二元结构设计范式&#xff0c;难以同时满…...