当前位置: 首页 > news >正文

自定义数据集,使用 PyTorch 框架实现逻辑回归并保存模型,然后保存模型后再加载模型进行预测

在本文中,我们将展示如何使用 NumPy 创建自定义数据集,利用 PyTorch 实现一个简单的逻辑回归模型,并在训练完成后保存该模型,最后加载模型并用它进行预测。

1. 创建自定义数据集

首先,我们使用 NumPy 创建一个简单的二分类数据集。假设我们的数据集包含两个特征。

import numpy as np# 生成随机数据
np.random.seed(42)
X = np.random.randn(100, 2)  # 100个样本,2个特征
y = (X[:, 0] + X[:, 1] > 0).astype(int)  # 标签为1或0# 打印数据集的前5个样本
print(X[:5], y[:5])

2. 构建逻辑回归模型

接下来,我们使用 PyTorch 来构建一个简单的逻辑回归模型。PyTorch 提供了 torch.nn.Module 类,能够轻松实现神经网络模型。

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset# 转换 NumPy 数据为 PyTorch 张量
X_tensor = torch.tensor(X, dtype=torch.float32)
y_tensor = torch.tensor(y, dtype=torch.float32).view(-1, 1)# 创建数据集并加载
dataset = TensorDataset(X_tensor, y_tensor)
dataloader = DataLoader(dataset, batch_size=10, shuffle=True)# 定义逻辑回归模型
class LogisticRegressionModel(nn.Module):def __init__(self):super(LogisticRegressionModel, self).__init__()self.linear = nn.Linear(2, 1)  # 2个输入特征,1个输出def forward(self, x):return torch.sigmoid(self.linear(x))# 初始化模型
model = LogisticRegressionModel()

3. 训练模型

接下来,我们定义损失函数和优化器,并训练模型。

# 定义损失函数和优化器
criterion = nn.BCELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)# 训练模型
epochs = 1000
for epoch in range(epochs):for inputs, labels in dataloader:optimizer.zero_grad()  # 清空梯度outputs = model(inputs)  # 前向传播loss = criterion(outputs, labels)  # 计算损失loss.backward()  # 反向传播optimizer.step()  # 更新权重if (epoch + 1) % 100 == 0:print(f'Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}')

4. 保存模型

模型训练完成后,我们可以将模型保存到文件中。PyTorch 提供了 torch.save() 方法来保存模型的状态字典。

# 保存模型
torch.save(model.state_dict(), 'logistic_regression_model.pth')
print("模型已保存!")

5. 加载模型并进行预测

我们可以加载保存的模型,并对新数据进行预测。

# 加载模型
loaded_model = LogisticRegressionModel()
loaded_model.load_state_dict(torch.load('logistic_regression_model.pth'))
loaded_model.eval()  # 切换到评估模式# 进行预测
with torch.no_grad():test_data = torch.tensor([[1.5, -0.5]], dtype=torch.float32)prediction = loaded_model(test_data)print(f'预测值: {prediction.item():.4f}')

6. 总结

在这篇博客中,我们展示了如何使用 NumPy 创建一个简单的自定义数据集,并使用 PyTorch 实现一个逻辑回归模型。我们还展示了如何保存训练好的模型,并加载模型进行预测。通过保存和加载模型,我们可以在不同的时间或环境中重复使用已经训练好的模型,而不需要重新训练它。

相关文章:

自定义数据集,使用 PyTorch 框架实现逻辑回归并保存模型,然后保存模型后再加载模型进行预测

在本文中,我们将展示如何使用 NumPy 创建自定义数据集,利用 PyTorch 实现一个简单的逻辑回归模型,并在训练完成后保存该模型,最后加载模型并用它进行预测。 1. 创建自定义数据集 首先,我们使用 NumPy 创建一个简单的…...

UE5 特效

能帮到你的话,就给个赞吧 😘 文章目录 post processexposurebloomvignettesaturationunbound material材质蓝图alt z base colorconstant3Vector roughnessconstant metallicconstant pbrroughnessmetallicmake more realmake some areas rougher than o…...

CMAKE工程编译好后自动把可执行文件传输到远程开发板

# 设置 CMake 最低版本要求 cmake_minimum_required(VERSION 3.10)# 设置项目名称 project(MyProject)# 添加可执行文件,这里以项目名作为可执行文件的名称 add_executable(${PROJECT_NAME} main.cpp)# 设置开发板信息 set(DEVELOPMENT_BOARD_IP "192.168.1.10…...

Windows 程序设计7:文件的创建、打开与关闭

文章目录 前言一、文件的创建与打开CreateFile1. 创建新的空白文件2. 打开已存在文件3. 打开一个文件时,如果文件存在则打开,如果文件不存在则新创建文件4.打开一个文件,如果文件存在则打开文件并清空内容,文件不存在则 新创建文件…...

策略模式 - 策略模式的使用

引言 在软件开发中,设计模式是解决常见问题的经典解决方案。策略模式(Strategy Pattern)是行为型设计模式之一,它允许在运行时选择算法的行为。通过将算法封装在独立的类中,策略模式使得算法可以独立于使用它的客户端…...

具身智能研究报告

参考: (1)GTC大会&Figure:“具身智能”奇点已至 (2)2024中国具身智能创投报告 (3)2024年具身智能产业发展研究报告 (4)具身智能行业深度:发展…...

Windows安装Milvus

安装Milvus 安装Docker前置条件: 安装Mlivus方案一方案二 Attu管理端 安装Docker 系统:Windows 11 家庭中文版 Mlivus:V2.3.0 Attu: V2.3.10 前置条件: 启用“适用于 Linux 的 Windows 子系统”可选功能,才能在 Win…...

Excel分区间统计分析(等步长、不等步长、多维度)

在数据分析过程中,可能会需要统计不同数据区间的人数、某个数据区间的平均值或者进行分组区间统计,本文从excel函数到数据透视表的方法,从简单需求到复杂需求,采用不同的方法进行讲解,尤其是通过数据透视表的强大功能大…...

宝塔mysql数据库容量限制_宝塔数据库mysql-bin.000001占用磁盘空间过大

磁盘空间占用过多,排查后发现网站/www/wwwroot只占用7G,/www/server占用却高达8G,再深入排查发现/www/server/data目录下的mysql-bin.000001和mysql-bin.000002两个日志文件占去了1.5G空间。 百度后学到以下知识,做个记录。 mysql…...

LeetCode 2412.完成所有交易的初始最少钱数:【年度巨献】举例说明(讲明白),由难至简(手脚不乱),附Python一行版

【LetMeFly】2412.完成所有交易的初始最少钱数:【年度巨献】举例说明(讲明白),由难至简(手脚不乱),附Python一行版 文章目录 【LetMeFly】2412.完成所有交易的初始最少钱数:【年度巨献】举例说明(讲明白),由难至简(手脚…...

多人-多agent协同可能会挑战维纳的反馈

在多人-多Agent协同系统中,维纳的经典反馈机制将面临新的挑战,而协同过程中的“算计”(策略性决策与协调)成为实现高效协作的核心。 1、非线性与动态性 维纳的反馈理论(尤其是在控制理论中)通常假设系统的动…...

Go学习:类型转换需注意的点 以及 类型别名

目录 1. 类型转换 2. 类型别名 1. 类型转换 在从前的学习中,知道布尔bool类型变量只有两种值true或false,C/C、Python、JAVA等编程语言中,如果将布尔类型bool变量转换为整型int变量,通常采用 “0为假,非0为真”的方…...

C语言中的局部变量和全局变量有什么区别?

在C语言中,局部变量和全局变量是两种具有不同作用域和存储期的变量。以下是它们之间的主要区别: 作用域 局部变量: 局部变量是在函数内部声明的变量。它们的作用域仅限于声明它们的函数内部。一旦函数执行完毕,局部变量就会超出…...

价值交换到底在交换什么

有人十多岁就很清醒,知道自己想要什么,要付出什么。有人20多岁清醒了,有人30多岁都不一定明白。 价值交换,四个字其实就可以解释大部分事情。价值交换和努力工作,勤劳没有任何关系。甚至努力和成功都不存在关系。 价值…...

C++传送锚点的内存寻址:内存管理

文章目录 1.C/C内存分布回顾2.C内存管理2.1 内存申请2.2 operator new与operator delete函数2.3 定位new表达式 3.关于内存管理的常见知识点3.1 malloc/free和new/delete的区别3.2 内存泄漏 希望读者们多多三连支持小编会继续更新你们的鼓励就是我前进的动力! 继C语…...

Prompt提示词完整案例:让chatGPT成为“书单推荐”的高手

大家好,我是老六哥,我正在共享使用AI提高工作效率的技巧。欢迎关注我,共同提高使用AI的技能,让AI成功你的个人助理。 许多人可能会跟老六哥一样,有过这样的体验:当我们遇到一个能力出众或对事物有独到见解的…...

基于django的智能停车场车辆管理深度学习车牌识别系统

完整源码项目包获取→点击文章末尾名片!...

【Proteus仿真】【51单片机】简易计算器系统设计

目录 一、主要功能 二、使用步骤 三、硬件资源 四、软件设计 五、实验现象 联系作者 一、主要功能 1、LCD1602液晶显示 2、矩阵按键​ 3、可以进行简单的加减乘除运算 4、最大 9999*9999 二、使用步骤 系统运行后,LCD1602显示数据,通过矩阵按键…...

洛谷P3884 [JLOI2009] 二叉树问题(详解)c++

题目链接:P3884 [JLOI2009] 二叉树问题 - 洛谷 | 计算机科学教育新生态 1.题目解析 1:从8走向6的最短路径,向根节点就是向上走,从8到1会经过三条边,向叶节点就是向下走,从1走到6需要经过两条边&#xff0c…...

《Foundation 起步》

《Foundation 起步》 引言 Foundation 是一个广泛使用的开源前端框架,由 ZURB 团队创建。它旨在帮助开发者构建响应式、可访问性和移动优先的网页。本文将为您提供一个全面的指南,帮助您从零开始学习并使用 Foundation。 Foundation 简介 什么是 Foundation? Foundatio…...

React Native 导航系统实战(React Navigation)

导航系统实战(React Navigation) React Navigation 是 React Native 应用中最常用的导航库之一,它提供了多种导航模式,如堆栈导航(Stack Navigator)、标签导航(Tab Navigator)和抽屉…...

Cilium动手实验室: 精通之旅---20.Isovalent Enterprise for Cilium: Zero Trust Visibility

Cilium动手实验室: 精通之旅---20.Isovalent Enterprise for Cilium: Zero Trust Visibility 1. 实验室环境1.1 实验室环境1.2 小测试 2. The Endor System2.1 部署应用2.2 检查现有策略 3. Cilium 策略实体3.1 创建 allow-all 网络策略3.2 在 Hubble CLI 中验证网络策略源3.3 …...

CocosCreator 之 JavaScript/TypeScript和Java的相互交互

引擎版本: 3.8.1 语言: JavaScript/TypeScript、C、Java 环境:Window 参考:Java原生反射机制 您好,我是鹤九日! 回顾 在上篇文章中:CocosCreator Android项目接入UnityAds 广告SDK。 我们简单讲…...

Linux云原生安全:零信任架构与机密计算

Linux云原生安全:零信任架构与机密计算 构建坚不可摧的云原生防御体系 引言:云原生安全的范式革命 随着云原生技术的普及,安全边界正在从传统的网络边界向工作负载内部转移。Gartner预测,到2025年,零信任架构将成为超…...

让AI看见世界:MCP协议与服务器的工作原理

让AI看见世界:MCP协议与服务器的工作原理 MCP(Model Context Protocol)是一种创新的通信协议,旨在让大型语言模型能够安全、高效地与外部资源进行交互。在AI技术快速发展的今天,MCP正成为连接AI与现实世界的重要桥梁。…...

根据万维钢·精英日课6的内容,使用AI(2025)可以参考以下方法:

根据万维钢精英日课6的内容,使用AI(2025)可以参考以下方法: 四个洞见 模型已经比人聪明:以ChatGPT o3为代表的AI非常强大,能运用高级理论解释道理、引用最新学术论文,生成对顶尖科学家都有用的…...

Rapidio门铃消息FIFO溢出机制

关于RapidIO门铃消息FIFO的溢出机制及其与中断抖动的关系,以下是深入解析: 门铃FIFO溢出的本质 在RapidIO系统中,门铃消息FIFO是硬件控制器内部的缓冲区,用于临时存储接收到的门铃消息(Doorbell Message)。…...

项目部署到Linux上时遇到的错误(Redis,MySQL,无法正确连接,地址占用问题)

Redis无法正确连接 在运行jar包时出现了这样的错误 查询得知问题核心在于Redis连接失败,具体原因是客户端发送了密码认证请求,但Redis服务器未设置密码 1.为Redis设置密码(匹配客户端配置) 步骤: 1).修…...

laravel8+vue3.0+element-plus搭建方法

创建 laravel8 项目 composer create-project --prefer-dist laravel/laravel laravel8 8.* 安装 laravel/ui composer require laravel/ui 修改 package.json 文件 "devDependencies": {"vue/compiler-sfc": "^3.0.7","axios": …...

高效线程安全的单例模式:Python 中的懒加载与自定义初始化参数

高效线程安全的单例模式:Python 中的懒加载与自定义初始化参数 在软件开发中,单例模式(Singleton Pattern)是一种常见的设计模式,确保一个类仅有一个实例,并提供一个全局访问点。在多线程环境下,实现单例模式时需要注意线程安全问题,以防止多个线程同时创建实例,导致…...