当前位置: 首页 > news >正文

自定义数据集,使用 PyTorch 框架实现逻辑回归并保存模型,然后保存模型后再加载模型进行预测

在本文中,我们将展示如何使用 NumPy 创建自定义数据集,利用 PyTorch 实现一个简单的逻辑回归模型,并在训练完成后保存该模型,最后加载模型并用它进行预测。

1. 创建自定义数据集

首先,我们使用 NumPy 创建一个简单的二分类数据集。假设我们的数据集包含两个特征。

import numpy as np# 生成随机数据
np.random.seed(42)
X = np.random.randn(100, 2)  # 100个样本,2个特征
y = (X[:, 0] + X[:, 1] > 0).astype(int)  # 标签为1或0# 打印数据集的前5个样本
print(X[:5], y[:5])

2. 构建逻辑回归模型

接下来,我们使用 PyTorch 来构建一个简单的逻辑回归模型。PyTorch 提供了 torch.nn.Module 类,能够轻松实现神经网络模型。

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset# 转换 NumPy 数据为 PyTorch 张量
X_tensor = torch.tensor(X, dtype=torch.float32)
y_tensor = torch.tensor(y, dtype=torch.float32).view(-1, 1)# 创建数据集并加载
dataset = TensorDataset(X_tensor, y_tensor)
dataloader = DataLoader(dataset, batch_size=10, shuffle=True)# 定义逻辑回归模型
class LogisticRegressionModel(nn.Module):def __init__(self):super(LogisticRegressionModel, self).__init__()self.linear = nn.Linear(2, 1)  # 2个输入特征,1个输出def forward(self, x):return torch.sigmoid(self.linear(x))# 初始化模型
model = LogisticRegressionModel()

3. 训练模型

接下来,我们定义损失函数和优化器,并训练模型。

# 定义损失函数和优化器
criterion = nn.BCELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)# 训练模型
epochs = 1000
for epoch in range(epochs):for inputs, labels in dataloader:optimizer.zero_grad()  # 清空梯度outputs = model(inputs)  # 前向传播loss = criterion(outputs, labels)  # 计算损失loss.backward()  # 反向传播optimizer.step()  # 更新权重if (epoch + 1) % 100 == 0:print(f'Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}')

4. 保存模型

模型训练完成后,我们可以将模型保存到文件中。PyTorch 提供了 torch.save() 方法来保存模型的状态字典。

# 保存模型
torch.save(model.state_dict(), 'logistic_regression_model.pth')
print("模型已保存!")

5. 加载模型并进行预测

我们可以加载保存的模型,并对新数据进行预测。

# 加载模型
loaded_model = LogisticRegressionModel()
loaded_model.load_state_dict(torch.load('logistic_regression_model.pth'))
loaded_model.eval()  # 切换到评估模式# 进行预测
with torch.no_grad():test_data = torch.tensor([[1.5, -0.5]], dtype=torch.float32)prediction = loaded_model(test_data)print(f'预测值: {prediction.item():.4f}')

6. 总结

在这篇博客中,我们展示了如何使用 NumPy 创建一个简单的自定义数据集,并使用 PyTorch 实现一个逻辑回归模型。我们还展示了如何保存训练好的模型,并加载模型进行预测。通过保存和加载模型,我们可以在不同的时间或环境中重复使用已经训练好的模型,而不需要重新训练它。

相关文章:

自定义数据集,使用 PyTorch 框架实现逻辑回归并保存模型,然后保存模型后再加载模型进行预测

在本文中,我们将展示如何使用 NumPy 创建自定义数据集,利用 PyTorch 实现一个简单的逻辑回归模型,并在训练完成后保存该模型,最后加载模型并用它进行预测。 1. 创建自定义数据集 首先,我们使用 NumPy 创建一个简单的…...

UE5 特效

能帮到你的话,就给个赞吧 😘 文章目录 post processexposurebloomvignettesaturationunbound material材质蓝图alt z base colorconstant3Vector roughnessconstant metallicconstant pbrroughnessmetallicmake more realmake some areas rougher than o…...

CMAKE工程编译好后自动把可执行文件传输到远程开发板

# 设置 CMake 最低版本要求 cmake_minimum_required(VERSION 3.10)# 设置项目名称 project(MyProject)# 添加可执行文件,这里以项目名作为可执行文件的名称 add_executable(${PROJECT_NAME} main.cpp)# 设置开发板信息 set(DEVELOPMENT_BOARD_IP "192.168.1.10…...

Windows 程序设计7:文件的创建、打开与关闭

文章目录 前言一、文件的创建与打开CreateFile1. 创建新的空白文件2. 打开已存在文件3. 打开一个文件时,如果文件存在则打开,如果文件不存在则新创建文件4.打开一个文件,如果文件存在则打开文件并清空内容,文件不存在则 新创建文件…...

策略模式 - 策略模式的使用

引言 在软件开发中,设计模式是解决常见问题的经典解决方案。策略模式(Strategy Pattern)是行为型设计模式之一,它允许在运行时选择算法的行为。通过将算法封装在独立的类中,策略模式使得算法可以独立于使用它的客户端…...

具身智能研究报告

参考: (1)GTC大会&Figure:“具身智能”奇点已至 (2)2024中国具身智能创投报告 (3)2024年具身智能产业发展研究报告 (4)具身智能行业深度:发展…...

Windows安装Milvus

安装Milvus 安装Docker前置条件: 安装Mlivus方案一方案二 Attu管理端 安装Docker 系统:Windows 11 家庭中文版 Mlivus:V2.3.0 Attu: V2.3.10 前置条件: 启用“适用于 Linux 的 Windows 子系统”可选功能,才能在 Win…...

Excel分区间统计分析(等步长、不等步长、多维度)

在数据分析过程中,可能会需要统计不同数据区间的人数、某个数据区间的平均值或者进行分组区间统计,本文从excel函数到数据透视表的方法,从简单需求到复杂需求,采用不同的方法进行讲解,尤其是通过数据透视表的强大功能大…...

宝塔mysql数据库容量限制_宝塔数据库mysql-bin.000001占用磁盘空间过大

磁盘空间占用过多,排查后发现网站/www/wwwroot只占用7G,/www/server占用却高达8G,再深入排查发现/www/server/data目录下的mysql-bin.000001和mysql-bin.000002两个日志文件占去了1.5G空间。 百度后学到以下知识,做个记录。 mysql…...

LeetCode 2412.完成所有交易的初始最少钱数:【年度巨献】举例说明(讲明白),由难至简(手脚不乱),附Python一行版

【LetMeFly】2412.完成所有交易的初始最少钱数:【年度巨献】举例说明(讲明白),由难至简(手脚不乱),附Python一行版 文章目录 【LetMeFly】2412.完成所有交易的初始最少钱数:【年度巨献】举例说明(讲明白),由难至简(手脚…...

多人-多agent协同可能会挑战维纳的反馈

在多人-多Agent协同系统中,维纳的经典反馈机制将面临新的挑战,而协同过程中的“算计”(策略性决策与协调)成为实现高效协作的核心。 1、非线性与动态性 维纳的反馈理论(尤其是在控制理论中)通常假设系统的动…...

Go学习:类型转换需注意的点 以及 类型别名

目录 1. 类型转换 2. 类型别名 1. 类型转换 在从前的学习中,知道布尔bool类型变量只有两种值true或false,C/C、Python、JAVA等编程语言中,如果将布尔类型bool变量转换为整型int变量,通常采用 “0为假,非0为真”的方…...

C语言中的局部变量和全局变量有什么区别?

在C语言中,局部变量和全局变量是两种具有不同作用域和存储期的变量。以下是它们之间的主要区别: 作用域 局部变量: 局部变量是在函数内部声明的变量。它们的作用域仅限于声明它们的函数内部。一旦函数执行完毕,局部变量就会超出…...

价值交换到底在交换什么

有人十多岁就很清醒,知道自己想要什么,要付出什么。有人20多岁清醒了,有人30多岁都不一定明白。 价值交换,四个字其实就可以解释大部分事情。价值交换和努力工作,勤劳没有任何关系。甚至努力和成功都不存在关系。 价值…...

C++传送锚点的内存寻址:内存管理

文章目录 1.C/C内存分布回顾2.C内存管理2.1 内存申请2.2 operator new与operator delete函数2.3 定位new表达式 3.关于内存管理的常见知识点3.1 malloc/free和new/delete的区别3.2 内存泄漏 希望读者们多多三连支持小编会继续更新你们的鼓励就是我前进的动力! 继C语…...

Prompt提示词完整案例:让chatGPT成为“书单推荐”的高手

大家好,我是老六哥,我正在共享使用AI提高工作效率的技巧。欢迎关注我,共同提高使用AI的技能,让AI成功你的个人助理。 许多人可能会跟老六哥一样,有过这样的体验:当我们遇到一个能力出众或对事物有独到见解的…...

基于django的智能停车场车辆管理深度学习车牌识别系统

完整源码项目包获取→点击文章末尾名片!...

【Proteus仿真】【51单片机】简易计算器系统设计

目录 一、主要功能 二、使用步骤 三、硬件资源 四、软件设计 五、实验现象 联系作者 一、主要功能 1、LCD1602液晶显示 2、矩阵按键​ 3、可以进行简单的加减乘除运算 4、最大 9999*9999 二、使用步骤 系统运行后,LCD1602显示数据,通过矩阵按键…...

洛谷P3884 [JLOI2009] 二叉树问题(详解)c++

题目链接:P3884 [JLOI2009] 二叉树问题 - 洛谷 | 计算机科学教育新生态 1.题目解析 1:从8走向6的最短路径,向根节点就是向上走,从8到1会经过三条边,向叶节点就是向下走,从1走到6需要经过两条边&#xff0c…...

《Foundation 起步》

《Foundation 起步》 引言 Foundation 是一个广泛使用的开源前端框架,由 ZURB 团队创建。它旨在帮助开发者构建响应式、可访问性和移动优先的网页。本文将为您提供一个全面的指南,帮助您从零开始学习并使用 Foundation。 Foundation 简介 什么是 Foundation? Foundatio…...

Duix-Avatar全离线数字人创作平台深度指南:从部署到高级应用

Duix-Avatar全离线数字人创作平台深度指南:从部署到高级应用 【免费下载链接】Duix-Avatar 项目地址: https://gitcode.com/GitHub_Trending/he/Duix-Avatar 价值解析:Duix-Avatar的SWOT战略分析 优势(Strengths) 全栈本地化架构:所…...

CCMusic跨平台部署指南:Windows/Linux/macOS全适配

CCMusic跨平台部署指南:Windows/Linux/macOS全适配 音乐风格识别从未如此简单——无论你用哪种电脑系统 1. 开篇:为什么需要跨平台部署方案 还在为音乐风格分类工具的安装头疼吗?不同的操作系统、不同的环境配置、复杂的依赖关系...这些麻烦…...

PathOfBuilding:流放之路玩家的离线构建神器,打造最强角色规划方案

PathOfBuilding:流放之路玩家的离线构建神器,打造最强角色规划方案 【免费下载链接】PathOfBuilding Offline build planner for Path of Exile. 项目地址: https://gitcode.com/GitHub_Trending/pa/PathOfBuilding 你是否曾经在《流放之路》中花…...

ssm+java2026年毕设随心淘网管理系统【源码+论文】

本系统(程序源码)带文档lw万字以上 文末可获取一份本项目的java源码和数据库参考。系统程序文件列表开题报告内容一、选题背景关于电商会员管理系统的研究,现有研究主要以大型综合电商平台(如淘宝、京东)的整体架构设计…...

跨设备滚动优化:Scroll Reverser让macOS操作效率提升80%的效率工具

跨设备滚动优化:Scroll Reverser让macOS操作效率提升80%的效率工具 【免费下载链接】Scroll-Reverser Per-device scrolling prefs on macOS. 项目地址: https://gitcode.com/gh_mirrors/sc/Scroll-Reverser 在当今多设备办公环境中,Mac用户常常面…...

用Python脚本让Crazyflie 2.X无人机动起来:手把手教你写第一个自主飞行程序

用Python脚本让Crazyflie 2.X无人机动起来:从零编写自主飞行程序 当第一次看到Crazyflie这个巴掌大的无人机在桌面上悬停时,我意识到微小型飞行器的编程控制远比想象中更有趣。与传统无人机不同,Crazyflie 2.X系列通过Python脚本就能实现毫米…...

Python中数据映射与转换的实现方法

在Python编程中,数据映射与转换是数据处理过程中的核心环节,广泛应用于数据清洗、格式转换、特征工程等多个领域。本文将系统梳理Python中实现数据映射与转换的多种方法,涵盖基础技巧、进阶应用及第三方库的高效实现,帮助开发者构…...

平面六杆机构的运动仿真(毕业论文+CAD图纸+开题报告+外文翻译)

平面六杆机构作为机械传动领域的重要构件,其运动特性直接影响机械系统的整体性能。该机构由六个刚性杆件通过转动副或移动副连接形成闭合环路,通过调整杆长比例与铰链位置,可实现复杂轨迹输出与多自由度运动控制。相较于四杆机构,…...

Windows 11下xray安装全流程:从下载到配置证书的保姆级教程

Windows 11安全工具配置全指南:从零开始搭建本地测试环境 在数字化生活日益普及的今天,个人电脑安全越来越受到重视。对于技术爱好者而言,了解和使用专业安全工具不仅能提升自身防护能力,也是学习网络安全知识的重要途径。本文将详…...

从零封装Vue版JSMpeg播放器:支持截图/录制/旋转的直播流组件开发指南

从零封装Vue版JSMpeg播放器:支持截图/录制/旋转的直播流组件开发指南 1. 技术选型与架构设计 在Web端实现低延迟视频直播需要解决三个核心问题:编解码效率、传输协议选择和渲染性能。基于JSMpeg的方案优势在于: 超低延迟(可达50ms…...