自定义数据集,使用 PyTorch 框架实现逻辑回归并保存模型,然后保存模型后再加载模型进行预测
在本文中,我们将展示如何使用 NumPy 创建自定义数据集,利用 PyTorch 实现一个简单的逻辑回归模型,并在训练完成后保存该模型,最后加载模型并用它进行预测。
1. 创建自定义数据集
首先,我们使用 NumPy 创建一个简单的二分类数据集。假设我们的数据集包含两个特征。
import numpy as np# 生成随机数据
np.random.seed(42)
X = np.random.randn(100, 2) # 100个样本,2个特征
y = (X[:, 0] + X[:, 1] > 0).astype(int) # 标签为1或0# 打印数据集的前5个样本
print(X[:5], y[:5])
2. 构建逻辑回归模型
接下来,我们使用 PyTorch 来构建一个简单的逻辑回归模型。PyTorch 提供了 torch.nn.Module
类,能够轻松实现神经网络模型。
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset# 转换 NumPy 数据为 PyTorch 张量
X_tensor = torch.tensor(X, dtype=torch.float32)
y_tensor = torch.tensor(y, dtype=torch.float32).view(-1, 1)# 创建数据集并加载
dataset = TensorDataset(X_tensor, y_tensor)
dataloader = DataLoader(dataset, batch_size=10, shuffle=True)# 定义逻辑回归模型
class LogisticRegressionModel(nn.Module):def __init__(self):super(LogisticRegressionModel, self).__init__()self.linear = nn.Linear(2, 1) # 2个输入特征,1个输出def forward(self, x):return torch.sigmoid(self.linear(x))# 初始化模型
model = LogisticRegressionModel()
3. 训练模型
接下来,我们定义损失函数和优化器,并训练模型。
# 定义损失函数和优化器
criterion = nn.BCELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)# 训练模型
epochs = 1000
for epoch in range(epochs):for inputs, labels in dataloader:optimizer.zero_grad() # 清空梯度outputs = model(inputs) # 前向传播loss = criterion(outputs, labels) # 计算损失loss.backward() # 反向传播optimizer.step() # 更新权重if (epoch + 1) % 100 == 0:print(f'Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}')
4. 保存模型
模型训练完成后,我们可以将模型保存到文件中。PyTorch 提供了 torch.save()
方法来保存模型的状态字典。
# 保存模型
torch.save(model.state_dict(), 'logistic_regression_model.pth')
print("模型已保存!")
5. 加载模型并进行预测
我们可以加载保存的模型,并对新数据进行预测。
# 加载模型
loaded_model = LogisticRegressionModel()
loaded_model.load_state_dict(torch.load('logistic_regression_model.pth'))
loaded_model.eval() # 切换到评估模式# 进行预测
with torch.no_grad():test_data = torch.tensor([[1.5, -0.5]], dtype=torch.float32)prediction = loaded_model(test_data)print(f'预测值: {prediction.item():.4f}')
6. 总结
在这篇博客中,我们展示了如何使用 NumPy 创建一个简单的自定义数据集,并使用 PyTorch 实现一个逻辑回归模型。我们还展示了如何保存训练好的模型,并加载模型进行预测。通过保存和加载模型,我们可以在不同的时间或环境中重复使用已经训练好的模型,而不需要重新训练它。
相关文章:
自定义数据集,使用 PyTorch 框架实现逻辑回归并保存模型,然后保存模型后再加载模型进行预测
在本文中,我们将展示如何使用 NumPy 创建自定义数据集,利用 PyTorch 实现一个简单的逻辑回归模型,并在训练完成后保存该模型,最后加载模型并用它进行预测。 1. 创建自定义数据集 首先,我们使用 NumPy 创建一个简单的…...
UE5 特效
能帮到你的话,就给个赞吧 😘 文章目录 post processexposurebloomvignettesaturationunbound material材质蓝图alt z base colorconstant3Vector roughnessconstant metallicconstant pbrroughnessmetallicmake more realmake some areas rougher than o…...
CMAKE工程编译好后自动把可执行文件传输到远程开发板
# 设置 CMake 最低版本要求 cmake_minimum_required(VERSION 3.10)# 设置项目名称 project(MyProject)# 添加可执行文件,这里以项目名作为可执行文件的名称 add_executable(${PROJECT_NAME} main.cpp)# 设置开发板信息 set(DEVELOPMENT_BOARD_IP "192.168.1.10…...

Windows 程序设计7:文件的创建、打开与关闭
文章目录 前言一、文件的创建与打开CreateFile1. 创建新的空白文件2. 打开已存在文件3. 打开一个文件时,如果文件存在则打开,如果文件不存在则新创建文件4.打开一个文件,如果文件存在则打开文件并清空内容,文件不存在则 新创建文件…...
策略模式 - 策略模式的使用
引言 在软件开发中,设计模式是解决常见问题的经典解决方案。策略模式(Strategy Pattern)是行为型设计模式之一,它允许在运行时选择算法的行为。通过将算法封装在独立的类中,策略模式使得算法可以独立于使用它的客户端…...

具身智能研究报告
参考: (1)GTC大会&Figure:“具身智能”奇点已至 (2)2024中国具身智能创投报告 (3)2024年具身智能产业发展研究报告 (4)具身智能行业深度:发展…...

Windows安装Milvus
安装Milvus 安装Docker前置条件: 安装Mlivus方案一方案二 Attu管理端 安装Docker 系统:Windows 11 家庭中文版 Mlivus:V2.3.0 Attu: V2.3.10 前置条件: 启用“适用于 Linux 的 Windows 子系统”可选功能,才能在 Win…...

Excel分区间统计分析(等步长、不等步长、多维度)
在数据分析过程中,可能会需要统计不同数据区间的人数、某个数据区间的平均值或者进行分组区间统计,本文从excel函数到数据透视表的方法,从简单需求到复杂需求,采用不同的方法进行讲解,尤其是通过数据透视表的强大功能大…...

宝塔mysql数据库容量限制_宝塔数据库mysql-bin.000001占用磁盘空间过大
磁盘空间占用过多,排查后发现网站/www/wwwroot只占用7G,/www/server占用却高达8G,再深入排查发现/www/server/data目录下的mysql-bin.000001和mysql-bin.000002两个日志文件占去了1.5G空间。 百度后学到以下知识,做个记录。 mysql…...
LeetCode 2412.完成所有交易的初始最少钱数:【年度巨献】举例说明(讲明白),由难至简(手脚不乱),附Python一行版
【LetMeFly】2412.完成所有交易的初始最少钱数:【年度巨献】举例说明(讲明白),由难至简(手脚不乱),附Python一行版 文章目录 【LetMeFly】2412.完成所有交易的初始最少钱数:【年度巨献】举例说明(讲明白),由难至简(手脚…...

多人-多agent协同可能会挑战维纳的反馈
在多人-多Agent协同系统中,维纳的经典反馈机制将面临新的挑战,而协同过程中的“算计”(策略性决策与协调)成为实现高效协作的核心。 1、非线性与动态性 维纳的反馈理论(尤其是在控制理论中)通常假设系统的动…...

Go学习:类型转换需注意的点 以及 类型别名
目录 1. 类型转换 2. 类型别名 1. 类型转换 在从前的学习中,知道布尔bool类型变量只有两种值true或false,C/C、Python、JAVA等编程语言中,如果将布尔类型bool变量转换为整型int变量,通常采用 “0为假,非0为真”的方…...
C语言中的局部变量和全局变量有什么区别?
在C语言中,局部变量和全局变量是两种具有不同作用域和存储期的变量。以下是它们之间的主要区别: 作用域 局部变量: 局部变量是在函数内部声明的变量。它们的作用域仅限于声明它们的函数内部。一旦函数执行完毕,局部变量就会超出…...
价值交换到底在交换什么
有人十多岁就很清醒,知道自己想要什么,要付出什么。有人20多岁清醒了,有人30多岁都不一定明白。 价值交换,四个字其实就可以解释大部分事情。价值交换和努力工作,勤劳没有任何关系。甚至努力和成功都不存在关系。 价值…...

C++传送锚点的内存寻址:内存管理
文章目录 1.C/C内存分布回顾2.C内存管理2.1 内存申请2.2 operator new与operator delete函数2.3 定位new表达式 3.关于内存管理的常见知识点3.1 malloc/free和new/delete的区别3.2 内存泄漏 希望读者们多多三连支持小编会继续更新你们的鼓励就是我前进的动力! 继C语…...

Prompt提示词完整案例:让chatGPT成为“书单推荐”的高手
大家好,我是老六哥,我正在共享使用AI提高工作效率的技巧。欢迎关注我,共同提高使用AI的技能,让AI成功你的个人助理。 许多人可能会跟老六哥一样,有过这样的体验:当我们遇到一个能力出众或对事物有独到见解的…...

基于django的智能停车场车辆管理深度学习车牌识别系统
完整源码项目包获取→点击文章末尾名片!...

【Proteus仿真】【51单片机】简易计算器系统设计
目录 一、主要功能 二、使用步骤 三、硬件资源 四、软件设计 五、实验现象 联系作者 一、主要功能 1、LCD1602液晶显示 2、矩阵按键 3、可以进行简单的加减乘除运算 4、最大 9999*9999 二、使用步骤 系统运行后,LCD1602显示数据,通过矩阵按键…...

洛谷P3884 [JLOI2009] 二叉树问题(详解)c++
题目链接:P3884 [JLOI2009] 二叉树问题 - 洛谷 | 计算机科学教育新生态 1.题目解析 1:从8走向6的最短路径,向根节点就是向上走,从8到1会经过三条边,向叶节点就是向下走,从1走到6需要经过两条边,…...
《Foundation 起步》
《Foundation 起步》 引言 Foundation 是一个广泛使用的开源前端框架,由 ZURB 团队创建。它旨在帮助开发者构建响应式、可访问性和移动优先的网页。本文将为您提供一个全面的指南,帮助您从零开始学习并使用 Foundation。 Foundation 简介 什么是 Foundation? Foundatio…...
QMC5883L的驱动
简介 本篇文章的代码已经上传到了github上面,开源代码 作为一个电子罗盘模块,我们可以通过I2C从中获取偏航角yaw,相对于六轴陀螺仪的yaw,qmc5883l几乎不会零飘并且成本较低。 参考资料 QMC5883L磁场传感器驱动 QMC5883L磁力计…...
Java如何权衡是使用无序的数组还是有序的数组
在 Java 中,选择有序数组还是无序数组取决于具体场景的性能需求与操作特点。以下是关键权衡因素及决策指南: ⚖️ 核心权衡维度 维度有序数组无序数组查询性能二分查找 O(log n) ✅线性扫描 O(n) ❌插入/删除需移位维护顺序 O(n) ❌直接操作尾部 O(1) ✅内存开销与无序数组相…...
【磁盘】每天掌握一个Linux命令 - iostat
目录 【磁盘】每天掌握一个Linux命令 - iostat工具概述安装方式核心功能基础用法进阶操作实战案例面试题场景生产场景 注意事项 【磁盘】每天掌握一个Linux命令 - iostat 工具概述 iostat(I/O Statistics)是Linux系统下用于监视系统输入输出设备和CPU使…...
【HTML-16】深入理解HTML中的块元素与行内元素
HTML元素根据其显示特性可以分为两大类:块元素(Block-level Elements)和行内元素(Inline Elements)。理解这两者的区别对于构建良好的网页布局至关重要。本文将全面解析这两种元素的特性、区别以及实际应用场景。 1. 块元素(Block-level Elements) 1.1 基本特性 …...

自然语言处理——循环神经网络
自然语言处理——循环神经网络 循环神经网络应用到基于机器学习的自然语言处理任务序列到类别同步的序列到序列模式异步的序列到序列模式 参数学习和长程依赖问题基于门控的循环神经网络门控循环单元(GRU)长短期记忆神经网络(LSTM)…...
LeetCode - 199. 二叉树的右视图
题目 199. 二叉树的右视图 - 力扣(LeetCode) 思路 右视图是指从树的右侧看,对于每一层,只能看到该层最右边的节点。实现思路是: 使用深度优先搜索(DFS)按照"根-右-左"的顺序遍历树记录每个节点的深度对于…...

push [特殊字符] present
push 🆚 present 前言present和dismiss特点代码演示 push和pop特点代码演示 前言 在 iOS 开发中,push 和 present 是两种不同的视图控制器切换方式,它们有着显著的区别。 present和dismiss 特点 在当前控制器上方新建视图层级需要手动调用…...

【C++进阶篇】智能指针
C内存管理终极指南:智能指针从入门到源码剖析 一. 智能指针1.1 auto_ptr1.2 unique_ptr1.3 shared_ptr1.4 make_shared 二. 原理三. shared_ptr循环引用问题三. 线程安全问题四. 内存泄漏4.1 什么是内存泄漏4.2 危害4.3 避免内存泄漏 五. 最后 一. 智能指针 智能指…...

【Linux系统】Linux环境变量:系统配置的隐形指挥官
。# Linux系列 文章目录 前言一、环境变量的概念二、常见的环境变量三、环境变量特点及其相关指令3.1 环境变量的全局性3.2、环境变量的生命周期 四、环境变量的组织方式五、C语言对环境变量的操作5.1 设置环境变量:setenv5.2 删除环境变量:unsetenv5.3 遍历所有环境…...
在 Spring Boot 项目里,MYSQL中json类型字段使用
前言: 因为程序特殊需求导致,需要mysql数据库存储json类型数据,因此记录一下使用流程 1.java实体中新增字段 private List<User> users 2.增加mybatis-plus注解 TableField(typeHandler FastjsonTypeHandler.class) private Lis…...