使用 PyTorch 实现逻辑回归并评估模型性能
1. 逻辑回归简介
逻辑回归是一种用于解决二分类问题的算法。它通过一个逻辑函数(Sigmoid 函数)将线性回归的输出映射到 [0, 1] 区间内,从而将问题转化为概率预测问题。如果预测概率大于 0.5,则将样本分类为正类;否则分类为负类。
2. 数据准备
为了演示逻辑回归的效果,我们构造了一个简单的二维数据集,包含两类样本。每类样本有 7 个数据点,特征维度为 2。
class1_points = np.array([[1.9, 1.2],[1.5, 2.1],[1.9, 0.5],[1.5, 0.9],[0.9, 1.2],[1.1, 1.7],[1.4, 1.1]])class2_points = np.array([[3.2, 3.2],[3.7, 2.9],[3.2, 2.6],[1.7, 3.3],[3.4, 2.6],[4.1, 2.3],[3.0, 2.9]])
我们将这两类数据点的特征合并,并为每个数据点分配标签(0 表示第一类,1 表示第二类)。
3. 模型构建
我们使用 PyTorch 框架来实现逻辑回归模型。模型结构非常简单,仅包含一个线性层和一个 Sigmoid 激活函数。
class LogisticRegression(nn.Module):def __init__(self):super(LogisticRegression, self).__init__()self.linear = nn.Linear(2, 1) # 输入特征维度为 2,输出为 1def forward(self, x):return torch.sigmoid(self.linear(x))
4. 模型训练
我们使用二分类交叉熵损失函数(BCELoss)和随机梯度下降优化器(SGD)来训练模型。训练过程如下:
epochs = 5000
for epoch in range(epochs):model.train()optimizer.zero_grad()outputs = model(X)loss = criterion(outputs, y)loss.backward()optimizer.step()if (epoch + 1) % 100 == 0:print(f'Epoch [{epoch + 1}/{epochs}], Loss: {loss.item():.4f}')
训练过程中,我们每 100 个 epoch 打印一次损失值,以便观察模型的收敛情况。
5. 模型保存与加载
训练完成后,我们将模型的权重保存到文件中,方便后续加载和使用。
torch.save(model.state_dict(), 'model3.pth')
print("模型已保存")
加载模型时,我们创建一个新的模型实例,并使用 load_state_dict 方法加载保存的权重。
loaded_model = LogisticRegression()
loaded_model.load_state_dict(torch.load('model3.pth', map_location=torch.device('cpu')))
loaded_model.eval()
6. 模型预测与性能评估
加载模型后,我们使用模型对训练数据进行预测,并计算精确度、召回率和 F1 分数。
with torch.no_grad():predictions = loaded_model(X)predicted_labels = (predictions > 0.5).float()print("实际结果:", y.numpy().flatten())
print("预测结果:", predicted_labels.numpy().flatten())precision = precision_score(y.numpy().flatten(), predicted_labels.numpy().flatten())
recall = recall_score(y.numpy().flatten(), predicted_labels.numpy().flatten())
f1 = f1_score(y.numpy().flatten(), predicted_labels.numpy().flatten())print(f"精确度(Precision): {precision:.4f}")
print(f"召回率(Recall): {recall:.4f}")
print(f"F1 分数: {f1:.4f}")
7. 运行结果
运行上述代码后,我们得到了以下结果:
-
实际结果:[0. 0. 0. 0. 0. 0. 0. 1. 1. 1. 1. 1. 1. 1.]
-
预测结果:[0. 0. 0. 0. 0. 0. 0. 1. 1. 1. 1. 1. 1. 1.]
-
精确度(Precision):1.0000
-
召回率(Recall):1.0000
-
F1 分数:1.0000
从结果可以看出,模型在训练集上表现良好,精确度、召回率和 F1 分数均为 1.0000。
8. 完整代码
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import precision_score, recall_score, f1_score"""使用pytorch框架实现逻辑回归并保存模型,然后保存模型后再加载模型进行预测,对预测结果计算精确度和召回率及F1分数"""
# 提取特征和标签
class1_points = np.array([[1.9, 1.2],[1.5, 2.1],[1.9, 0.5],[1.5, 0.9],[0.9, 1.2],[1.1, 1.7],[1.4, 1.1]])class2_points = np.array([[3.2, 3.2],[3.7, 2.9],[3.2, 2.6],[1.7, 3.3],[3.4, 2.6],[4.1, 2.3],[3.0, 2.9]])# 提取两类特征,输入特征维度为2
x1_data = np.concatenate((class1_points[:, 0], class2_points[:, 0]), axis=0)
x2_data = np.concatenate((class1_points[:, 1], class2_points[:, 1]), axis=0)
label = np.concatenate((np.zeros(len(class1_points)), np.ones(len(class2_points))), axis=0)# 将数据转换为 PyTorch 张量
X = torch.tensor(np.column_stack((x1_data, x2_data)), dtype=torch.float32)
y = torch.tensor(label, dtype=torch.float32).view(-1, 1)# 定义逻辑回归模型
class LogisticRegression(nn.Module):def __init__(self):super(LogisticRegression, self).__init__()self.linear = nn.Linear(2, 1) # 输入特征维度为 2,输出为 1def forward(self, x):return torch.sigmoid(self.linear(x))# 初始化模型、损失函数和优化器
model = LogisticRegression()
criterion = nn.BCELoss() # 二分类交叉熵损失
optimizer = optim.SGD(model.parameters(), lr=0.01)# 训练模型
epochs = 5000
for epoch in range(epochs):model.train()optimizer.zero_grad()outputs = model(X)loss = criterion(outputs, y)loss.backward()optimizer.step()if (epoch + 1) % 100 == 0:print(f'Epoch [{epoch + 1}/{epochs}], Loss: {loss.item():.4f}')# 保存模型
torch.save(model.state_dict(), 'model3.pth')
print("模型已保存")# 加载模型
loaded_model = LogisticRegression()
loaded_model.load_state_dict(torch.load('model3.pth',map_location=torch.device('cpu'),weights_only=True))
loaded_model.eval()# 进行预测
with torch.no_grad():predictions = loaded_model(X)predicted_labels = (predictions > 0.5).float()# 展示预测结果和实际结果
print("实际结果:", y.numpy().flatten())
print("预测结果:", predicted_labels.numpy().flatten())# 计算精确度、召回率和 F1 分数
precision = precision_score(y.numpy().flatten(), predicted_labels.numpy().flatten())
recall = recall_score(y.numpy().flatten(), predicted_labels.numpy().flatten())
f1 = f1_score(y.numpy().flatten(), predicted_labels.numpy().flatten())print(f"精确度(Precision): {precision:.4f}")
print(f"召回率(Recall): {recall:.4f}")
print(f"F1 分数: {f1:.4f}")
相关文章:
使用 PyTorch 实现逻辑回归并评估模型性能
1. 逻辑回归简介 逻辑回归是一种用于解决二分类问题的算法。它通过一个逻辑函数(Sigmoid 函数)将线性回归的输出映射到 [0, 1] 区间内,从而将问题转化为概率预测问题。如果预测概率大于 0.5,则将样本分类为正类;否则分…...
python学opencv|读取图像(五十二)使用cv.matchTemplate()函数实现最佳图像匹配
【1】引言 前序学习了图像的常规读取和基本按位操作技巧,相关文章包括且不限于: python学opencv|读取图像-CSDN博客 python学opencv|读取图像(四十九)原理探究:使用cv2.bitwise()系列函数实现图像按位运算-CSDN博客…...
【VUE案例练习】前端vue2+element-ui,后端nodo+express实现‘‘文件上传/删除‘‘功能
近期在做跟毕业设计相关的数据后台管理系统,其中的列表项展示有图片展示,添加/编辑功能有文件上传。 “文件上传/删除”也是我们平时开发会遇到的一个功能,这里分享个人的实现过程,与大家交流谈论~ 一、准备工作 本次案例使用的…...
使用真实 Elasticsearch 进行高级集成测试
作者:来自 Elastic Piotr Przybyl 掌握高级 Elasticsearch 集成测试:更快、更智能、更优化。 在上一篇关于集成测试的文章中,我们介绍了如何通过改变数据初始化策略来缩短依赖于真实 Elasticsearch 的集成测试的执行时间。在本期中࿰…...
【R语言】函数
一、函数格式 如下所示: hello:函数名;function:定义的R对象是函数而不是其它变量;():函数的输入参数,可以为空,也可以包含参数;{}:函数体,如果…...
Vue 3 30天精进之旅:Day 12 - 异步操作
在现代前端开发中,异步操作是一个非常常见的需求,例如从后端API获取数据、进行文件上传等任务。Vue 3 结合组合式API和Vuex可以方便地处理这些异步操作。今天我们将重点学习如何在Vue应用中进行异步操作,包括以下几个主题: 异步操…...
VSCode插件Live Server
简介:插件Live Server能够实现当我们在VSCode编辑器里修改 HTML、CSS 或者 JavaScript 文件时,它都能自动实时地刷新浏览器页面,让我们实时看到代码变化的效果。再也不用手动刷新浏览器了,节省了大量的开发过程耗时! 1…...
50. 正点原子官方系统镜像烧写实验
一、Windows下使用OTG烧写系统 1、在Windos使用NXP提供的mfgtool来向开发烧写系统。需要用先将开发板的USB_OTG接口连接到电脑上。 Mfgtool工具是向板子先下载一个Linux系统,然后通过这个系统来完成烧写工作。 切记!使用OTG烧写的时候要先把SD卡拔出来&…...
在C#中,什么是多态如何实现
在C#中,什么是多态?如何实现? C#中的多态性 多态性是面向对象编程的一个核心概念,他允许对象以多种形式表现.在C#中,多态主要通过虚方法,抽象方法和接口来实现. 多态性的存在使得同一个行为可以有多个不同的表达形式 即同一个接口可以使用不同的实例来执行不同的操作 虚方…...
搜索引擎友好:设计快速收录的网站架构
本文来自:百万收录网 原文链接:https://www.baiwanshoulu.com/14.html 为了设计一个搜索引擎友好的网站架构,以实现快速收录,可以从以下几个方面入手: 一、清晰的目录结构与层级 合理划分内容:目录结构应…...
扩散模型(三)
相关阅读: 扩散模型(一) 扩散模型(二) Latent Variable Space 潜在扩散模型(LDM;龙巴赫、布拉特曼等人,2022 年)在潜在空间而非像素空间中运行扩散过程,这…...
it基础使用--5---git远程仓库
it基础使用–5—git远程仓库 1. 按顺序看 -git基础使用–1–版本控制的基本概念 -git基础使用–2–gti的基本概念 -git基础使用–3—安装和基本使用 -git基础使用–4—git分支和使用 2. 什么是远程仓库 在第一篇文章中,我们已经讲过了远程仓库,每个本…...
Baklib如何改变内容管理平台的未来推动创新与效率提升
内容概要 在信息爆炸的时代,内容管理平台成为了企业和个人不可或缺的工具。它通过高效组织、存储和发布内容,帮助用户有效地管理信息流。随着技术的发展,传统的内容管理平台逐渐暴露出灵活性不足、易用性差等局限性,这促使市场需…...
100.1 AI量化面试题:解释夏普比率(Sharpe Ratio)的计算方法及其在投资组合管理中的应用,并说明其局限性
目录 0. 承前1. 夏普比率的基本概念1.1 定义与计算方法1.2 实际计算示例 2. 在投资组合管理中的应用2.1 投资组合选择2.2 投资组合优化 3. 夏普比率的局限性3.1 统计假设的限制3.2 实践中的问题 4. 改进方案4.1 替代指标4.2 实践建议 5. 回答话术 0. 承前 如果想更加全面清晰地…...
Ubuntu 下 nginx-1.24.0 源码分析 ngx_debug_init();
目录 ngx_debug_init() 函数: NGX_LINUX 的定义: ngx_debug_init() 函数: ngx_debug_init() 函数定义在 src\os\unix 目录下的 ngx_linux_config.h 中 #define ngx_debug_init() 也就是说这个环境下的 main 函数中的 ngx_debug_init() 这…...
基于人脸识别的课堂考勤系统
该项目是一个基于人脸识别的课堂考勤系统,使用Python开发,结合了多种技术实现考勤功能。要开发类似的基于人脸识别的考勤系统,可参考以下步骤: 环境搭建:利用Anaconda创建虚拟环境,指定Python版本为3.8&am…...
开启 AI 学习之旅:从入门到精通
最近 AI 真的超火,不管是工作还是生活里,到处都能看到它的身影。好多小伙伴都跑来问我,到底该怎么学 AI 呢?今天我就把自己学习 AI 的经验和心得分享出来,希望能帮到想踏入 AI 领域的朋友们! 一、学习内容有哪些 (一)编程语言 Python 绝对是首选!它在 AI 领域的生态…...
13 尺寸结构模块(size.rs)
一、size.rs源码 // Copyright 2013 The Servo Project Developers. See the COPYRIGHT // file at the top-level directory of this distribution. // // Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or // http://www.apache.org/licenses/LICENSE…...
16.[前端开发]Day16-HTML+CSS阶段练习(网易云音乐五)
完整代码 网易云-main-left-rank(排行榜) <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta http-equiv"X-UA-Compatible" content"IEedge"><meta name&q…...
ARM嵌入式学习--第十天(UART)
--UART介绍 UART(Universal Asynchonous Receiver and Transmitter)通用异步接收器,是一种通用串行数据总线,用于异步通信。该总线双向通信,可以实现全双工传输和接收。在嵌入式设计中,UART用来与PC进行通信,包括与监控…...
MoonBit 编译器(留档学习)
MoonBit 编译器 MoonBit 是一个用户友好,构建快,产出质量高的编程语言。 MoonBit | Documentation | Tour | Core This is the source code repository for MoonBit, a programming language that is user-friendly, builds fast, and produces high q…...
【TypeScript】基础:数据类型
文章目录 TypeScript一、简介二、类型声明三、数据类型anyunknownnervervoidobjecttupleenumType一些特殊情况 TypeScript 是JavaScript的超集,代码量比JavaScript复杂、繁多;但是结构更清晰 一、简介 为什么需要TypeScript? JavaScript的…...
Unity游戏(Assault空对地打击)开发(3) 摄像机的控制
详细步骤 打开My Assets或者Package Manager。 选择Unity Registry。 搜索Cinemachine,找到 Cinemachine包,点击 Install按钮进行安装。 关闭窗口,新建一个FreeLook Camera,如下。 接着新建一个对象Pos,拖到Player下面…...
【HarmonyOS之旅】基于ArkTS开发(三) -> 兼容JS的类Web开发(三)
目录 1 -> 生命周期 1.1 -> 应用生命周期 1.2 -> 页面生命周期 2 -> 资源限定与访问 2.1 -> 资源限定词 2.2 -> 资源限定词的命名要求 2.3 -> 限定词与设备状态的匹配规则 2.4 -> 引用JS模块内resources资源 3 -> 多语言支持 3.1 -> 定…...
小程序-基础加强-自定义组件
前言 这次讲自定义组件 1. 准备今天要用到的项目 2. 初步创建并使用自定义组件 这样就成功在home中引入了test组件 在json中引用了这个组件才能用这个组件 现在我们来实现全局引用组件 在app.json这样使用就可以了 3. 自定义组件的样式 发现页面里面的文本和组件里面的文…...
尝试ai生成figma设计
当听到用ai 自动生成figma设计时,不免好奇这个是如何实现的。在查阅了不少资料后,有了一些想法。参考了:在figma上使用脚本自动生成色谱 这篇文章提供的主要思路是:可以通过脚本的方式构建figma设计。如果我们使用ai 生成figma脚本…...
【周易哲学】生辰八字入门讲解(八)
😊你好,我是小航,一个正在变秃、变强的文艺倾年。 🔔本文讲解【周易哲学】生辰八字入门讲解,期待与你一同探索、学习、进步,一起卷起来叭! 目录 一、六亲女命六亲星六亲宫位相互关系 男命六亲星…...
康德哲学与自组织思想的渊源:从《判断力批判》到系统论的桥梁
康德哲学与自组织思想的渊源:从《判断力批判》到系统论的桥梁 第一节:康德哲学中的自然目的论与自组织思想 核心内容: 康德哲学中的自然目的论和反思判断力概念,为现代系统论中的自组织思想提供了哲学基础,预见了复…...
解决whisper 本地运行时GPU 利用率不高的问题
我在windows 环境下本地运行whisper 模型,使用的是nivdia RTX4070 显卡,结果发现GPU 的利用率只有2% 。使用 import torch print(torch.cuda.is_available()) 返回TRUE。表示我的cuda 是可用的。 最后在github 的下列网页上找到了问题 极低的 GPU 利…...
【自开发工具介绍】SQLSERVER的ImpDp和ExpDp工具02
工具运行前的环境准备 1、登录用户管理员权限确认 工具使用的登录用户(-u后面的用户),必须具有管理员的权限,因为需要读取系统表 例:Export.bat -s 10.48.111.12 -d db1 -u test -p test -schema dbo 2、Powershell的安全策略确认…...
