当前位置: 首页 > article >正文

自定义数据集 使用pytorch框架实现逻辑回归并保存模型,然后保存模型后再加载模型进行预测

代码:

import torch
import numpy as np
import torch.nn as nn# 定义数据:x_data 是特征,y_data 是标签(目标值)
data = [[-0.5, 7.7],[1.8, 98.5],[0.9, 57.8],[0.4, 39.2],[-1.4, -15.7],[-1.4, -37.3],[-1.8, -49.1],[1.5, 75.6],[0.4, 34.0],[0.8, 62.3]]# 将数据转为 numpy 数组
data = np.array(data)# 提取 x_data 和 y_data
x_data = data[:, 0]  # 取第一列作为输入特征
y_data = data[:, 1]  # 取第二列作为目标标签# 将数据转换为 PyTorch 张量
x_train = torch.tensor(x_data, dtype=torch.float32)  # 输入特征
y_train = torch.tensor(y_data, dtype=torch.float32)  # 目标标签# 使用 TensorDataset 来创建一个数据集
from torch.utils.data import DataLoader, TensorDatasetdataset = TensorDataset(x_train, y_train)  # 使用训练数据创建数据集
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)  # 将数据集转换为 DataLoader,批大小为 2,且每个 epoch 都会随机打乱数据# 定义损失函数:均方误差损失 (MSELoss)
criterion = nn.MSELoss()# 定义线性回归模型
class LinearModel(nn.Module):def __init__(self):super(LinearModel, self).__init__()# 使用一个线性层,输入为1维,输出为1维self.layers = nn.Linear(1, 1)def forward(self, x):# 直接返回线性层的输出return self.layers(x)# 初始化模型
model = LinearModel()# 定义优化器:使用随机梯度下降 (SGD),学习率为 0.01
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)# 设置训练的 epoch 数量
epochs = 500# 开始训练
for n in range(1, epochs + 1):epoch_loss = 0  # 初始化每个 epoch 的总损失# 遍历每一个 batchfor batch_x, batch_y in dataloader:# 对每个输入 batch_x 添加一个维度,使其适应线性回归模型y_pred = model(batch_x.unsqueeze(1))# 计算当前 batch 的损失batch_loss = criterion(y_pred.squeeze(1), batch_y)# 清除梯度optimizer.zero_grad()# 反向传播batch_loss.backward()# 更新模型参数optimizer.step()# 累加损失epoch_loss += batch_loss# 计算当前 epoch 的平均损失avg_loss = epoch_loss / len(dataloader)# 每 10 个 epoch 或第一个 epoch 打印一次损失if n % 10 == 0 or n == 1:print(f"epoches:{n}, loss:{avg_loss:.4f}")# 在每 10 个 epoch 保存一次模型的状态字典torch.save(model.state_dict(), 'model.pth')# 训练完成后,加载保存的模型
model.load_state_dict(torch.load("model.pth"))# 将模型切换到评估模式
model.eval()# 测试数据
x_test = torch.tensor([[1.8]], dtype=torch.float32)  # 测试输入 1.8# 使用模型进行预测
with torch.no_grad():  # 在预测时不计算梯度y_pre = model(x_test)# 打印预测结果
print(y_pre)

结果:

相关文章:

自定义数据集 使用pytorch框架实现逻辑回归并保存模型,然后保存模型后再加载模型进行预测

代码: import torch import numpy as np import torch.nn as nn# 定义数据:x_data 是特征,y_data 是标签(目标值) data [[-0.5, 7.7],[1.8, 98.5],[0.9, 57.8],[0.4, 39.2],[-1.4, -15.7],[-1.4, -37.3],[-1.8, -49.…...

【MQ】如何保证消息队列的高可用?

RocketMQ NameServer集群部署 Broker做了集群部署 主从模式 类型:同步复制、异步复制 主节点返回消息给客户端的时候是否需要同步从节点 Dledger:要求至少消息复制到半数以上的节点之后,才给客户端返回写入成功 slave定时从master同步数据…...

本地大模型编程实战(04)给文本自动打标签

文章目录 准备实例化本地大模型情感分析更精细的控制总结代码 使用本地大模型可以根据需要给文本打标签,本文介绍了如何基于 langchain 和本地部署的大模型给文本打标签。 本文使用 llama3.1 作为本地大模型,它的性能比非开源大模型要查一下,…...

关于使用PHP时WordPress排错——“这意味着您在wp-config.php文件中指定的用户名和密码信息不正确”的解决办法

本来是看到一位好友的自己建站,所以突发奇想,在本地装个WordPress玩玩吧,就尝试着装了一下,因为之前电脑上就有MySQL,所以在自己使用PHP建立MySQL时报错了。 最开始是我的php启动mysql时有问题,也就是启动过…...

【蓝桥杯】43694.正则问题

题目描述 考虑一种简单的正则表达式: 只由 x ( ) | 组成的正则表达式。 小明想求出这个正则表达式能接受的最长字符串的长度。 例如 ((xx|xxx)x|(x|xx))xx 能接受的最长字符串是: xxxxxx,长度是 6。 输入描述 一个由 x()| 组成的正则表达式。…...

服务器虚拟化技术详解与实战:架构、部署与优化

📝个人主页🌹:一ge科研小菜鸡-CSDN博客 🌹🌹期待您的关注 🌹🌹 引言 在现代 IT 基础架构中,服务器虚拟化已成为提高资源利用率、降低运维成本、提升系统灵活性的重要手段。通过服务…...

git困扰的问题

.gitignore中添加的某个忽略文件并不生效 把某些目录或文件加入忽略规则,按照上述方法定义后发现并未生效, gitignore只能忽略那些原来没有被追踪的文件,如果某些文件已经被纳入了版本管理中,则修改.gitignore是无效的。 解决方…...

jvm--类的生命周期

学习类的生命周期之前,需要了解一下jvm的几个重要的内存区域: (1)方法区:存放已经加载的类信息、常量、静态变量以及方法代码的内存区域 (2)常量池:常量池是方法区的一部分&#x…...

定制Centos镜像(一)

环境准备: 一台最小化安装的干净的系统,这里使用Centos7.9,一个Centos镜像,镜像也使用Centos7.9的。 [rootlocalhost ~]# cat /etc/system-release CentOS Linux release 7.9.2009 (Core) [rootlocalhost ~]# rpm -qa | wc -l 306 [rootloca…...

C语言------数组思维导图

...

TensorFlow实现逻辑回归模型

逻辑回归是一种经典的分类算法,广泛应用于二分类问题。本文将介绍如何使用TensorFlow框架实现逻辑回归模型,并通过动态绘制决策边界和损失曲线来直观地观察模型的训练过程。 数据准备 首先,我们准备两类数据点,分别表示两个不同…...

《十七》浏览器基础

浏览器:是安装在电脑里面的一个软件,能够将页面内容渲染出来呈现给用户查看,并让用户与网页进行交互。 常见的主流浏览器: 常见的主流浏览器有:Chrome、Safari、Firefox、Opera、Edge 等。 输入 URL,浏览…...

Windows 靶机常见服务、端口及枚举工具与方法全解析:SMB、LDAP、NFS、RDP、WinRM、DNS

在渗透测试中,Windows 靶机通常会运行多种服务,每种服务都有其默认端口和常见的枚举工具及方法。以下是 Windows 靶机常见的服务、端口、枚举工具和方法的详细说明: 1. SMB(Server Message Block) 端口 445/TCP&…...

IME关于输入法横屏全屏显示问题-Android14

IME关于输入法横屏全屏显示问题-Android14 1、输入法全屏模式updateFullscreenMode1.1 全屏模式判断1.2 全屏模式布局设置 2、应用侧关闭输入法全屏模式2.1 调用输入法的应用设置flag2.2 继承InputMethodService.java的输入法应用覆盖onEvaluateFullscreenMode方法 InputMethod…...

网络安全 | F5-Attack Signatures-Set详解

关注:CodingTechWork 创建和分配攻击签名集 可以通过两种方式创建攻击签名集:使用过滤器或手动选择要包含的签名。  基于过滤器的签名集仅基于在签名过滤器中定义的标准。基于过滤器的签名集的优点在于,可以专注于定义用户感兴趣的攻击签名…...

STranslate 中文绿色版即时翻译/ OCR 工具 v1.3.1.120

STranslate 是一款功能强大且用户友好的翻译工具,它支持多种语言的即时翻译,提供丰富的翻译功能和便捷的使用体验。STranslate 特别适合需要频繁进行多语言交流的个人用户、商务人士和翻译工作者。 软件功能 1. 即时翻译: 文本翻译&#xff…...

基于微信小程序的助农扶贫系统设计与实现(LW+源码+讲解)

专注于大学生项目实战开发,讲解,毕业答疑辅导,欢迎高校老师/同行前辈交流合作✌。 技术范围:SpringBoot、Vue、SSM、HLMT、小程序、Jsp、PHP、Nodejs、Python、爬虫、数据可视化、安卓app、大数据、物联网、机器学习等设计与开发。 主要内容:…...

我谈区域偏心率

偏心率的数学定义 禹晶、肖创柏、廖庆敏《数字图像处理(面向新工科的电工电子信息基础课程系列教材)》P312 区域的拟合椭圆看这里。 Rafael Gonzalez的二阶中心矩的表达不说人话。 我认为半长轴和半短轴不等于特征值,而是特征值的根号。…...

关于低代码技术架构的思考

我们经常会看到很多低代码系统的技术架构图,而且经常看不懂。是因为技术架构图没有画好,还是因为技术不够先进,有时候往往都不是。 比如下图: 一个开发者,看到的视角往往都是技术层面,你给用户讲React18、M…...

若依路由配置教程

1. 路由配置文件 2. 配置内容介绍 { path: "/tool/gen-edit", component: Layout, //在路由下,引用组件的名称,在页面中包括这个组件的内容(页面框架内容) hidden: true, //此页面的内容,在左边的菜单中不用显示。 …...

基于ESP8266的多功能环境监测与反馈系统开发指南

项目概述 本系统集成了物联网开发板、高精度时钟模块、环境传感器和可视化显示模块,构建了一个智能环境监测与反馈装置。通过ESP8266 NodeMCU作为核心控制器,结合DS3231实时时钟、DHT11温湿度传感器、光敏电阻和OLED显示屏,实现了环境参数的…...

【Elasticsearch】doc_values 可以用于查询操作

确实,doc values 可以用于查询操作,尽管它们的主要用途是支持排序、聚合和脚本中的字段访问。在某些情况下,Elasticsearch 也会利用 doc values 来执行特定类型的查询。以下是关于 doc values 在查询操作中的使用及其影响的详细解释&#xff…...

HTML5 Web Worker 的使用与实践

引言 在现代 Web 开发中,用户体验是至关重要的。如果页面在执行复杂计算或处理大量数据时变得卡顿或无响应,用户很可能会流失。HTML5 引入了 Web Worker,它允许我们在后台运行 JavaScript 代码,从而避免阻塞主线程,保…...

flutter_学习记录_00_环境搭建

1.参考文档 Mac端Flutter的环境配置看这一篇就够了 flutter的中文官方文档 2. 本人环境搭建的背景 本人的电脑的是Mac的,iOS开发,所以iOS开发环境本身是可用的;外加Mac电脑本身就会配置Java的环境。所以,后面剩下的就是&#x…...

自助设备系统设置——对接POS支付

输入管理员密码 一、录入POS网关信息 填写网关信息后保存,重新启动软件...

Calibre(阅读转换)-官方开源中文版[完整的电子图书馆系统,包括图书馆管理,格式转换,新闻,材料转换为电子书]

Calibre(阅读&转换)-官方开源中文版 链接:https://pan.xunlei.com/s/VOHbKYUwd3ASVXTi2Ok1vkK3A1?pwd92ny#...

2748. 美丽下标对的数目(Beautiful Pairs)

2748. 美丽下标对的数目&#xff08;Beautiful Pairs&#xff09; 题目分析 给定一个整数数组 nums&#xff0c;我们需要找出其中符合条件的“美丽下标对”。美丽下标对是指&#xff0c;数组中的某一对数字 nums[i] 和 nums[j]&#xff08;其中 0 ≤ i < j < nums.leng…...

【unity游戏开发之InputSystem——06】PlayerInputManager组件实现本地多屏的游戏(基于unity6开发介绍)

文章目录 PlayerInputManager 简介1、PlayerInputManager 的作用2、主要功能一、PlayerInputManager组件参数1、Notification Behavior 通知行为2、Join Behavior:玩家加入的行为3、Player Prefab 玩家预制件4、Joining Enabled By Default 默认启用加入5、Limit Number Of Pl…...

算法刷题Day29:BM67 不同路径的数目(一)

题目链接 描述 解题思路&#xff1a; 二维dp数组初始化。 dp[i][0] 1, dp[0][j] 1 。因为到达第一行第一列的每个格子只能有一条路。状态转移 dp[i][j] dp[i-1][j] dp[i][j-1] 代码&#xff1a; class Solution: def uniquePaths(self , m: int, n: int) -> int: #…...

c语言网 1130数字母

原题 题目描述 输入一个字符串,数出其中的字母的个数. 输入格式 一个字符串&#xff0c;不包含空格&#xff08;长度小于100&#xff09; 输出格式 字符串中的字母的个数 样例输入 复制 124lfdk54AIEJ92854&%$GJ 样例输出 复制 10 #include<iostream> #include<cc…...