当前位置: 首页 > article >正文

使用 PyTorch 实现逻辑回归并评估模型性能

1. 逻辑回归简介

逻辑回归是一种用于解决二分类问题的算法。它通过一个逻辑函数(Sigmoid 函数)将线性回归的输出映射到 [0, 1] 区间内,从而将问题转化为概率预测问题。如果预测概率大于 0.5,则将样本分类为正类;否则分类为负类。

2. 数据准备

为了演示逻辑回归的效果,我们构造了一个简单的二维数据集,包含两类样本。每类样本有 7 个数据点,特征维度为 2。

class1_points = np.array([[1.9, 1.2],[1.5, 2.1],[1.9, 0.5],[1.5, 0.9],[0.9, 1.2],[1.1, 1.7],[1.4, 1.1]])class2_points = np.array([[3.2, 3.2],[3.7, 2.9],[3.2, 2.6],[1.7, 3.3],[3.4, 2.6],[4.1, 2.3],[3.0, 2.9]])

我们将这两类数据点的特征合并,并为每个数据点分配标签(0 表示第一类,1 表示第二类)。

3. 模型构建

我们使用 PyTorch 框架来实现逻辑回归模型。模型结构非常简单,仅包含一个线性层和一个 Sigmoid 激活函数。

class LogisticRegression(nn.Module):def __init__(self):super(LogisticRegression, self).__init__()self.linear = nn.Linear(2, 1)  # 输入特征维度为 2,输出为 1def forward(self, x):return torch.sigmoid(self.linear(x))

4. 模型训练

我们使用二分类交叉熵损失函数(BCELoss)和随机梯度下降优化器(SGD)来训练模型。训练过程如下:

epochs = 5000
for epoch in range(epochs):model.train()optimizer.zero_grad()outputs = model(X)loss = criterion(outputs, y)loss.backward()optimizer.step()if (epoch + 1) % 100 == 0:print(f'Epoch [{epoch + 1}/{epochs}], Loss: {loss.item():.4f}')

训练过程中,我们每 100 个 epoch 打印一次损失值,以便观察模型的收敛情况。

5. 模型保存与加载

训练完成后,我们将模型的权重保存到文件中,方便后续加载和使用。

torch.save(model.state_dict(), 'model3.pth')
print("模型已保存")

加载模型时,我们创建一个新的模型实例,并使用 load_state_dict 方法加载保存的权重。

loaded_model = LogisticRegression()
loaded_model.load_state_dict(torch.load('model3.pth', map_location=torch.device('cpu')))
loaded_model.eval()

6. 模型预测与性能评估

加载模型后,我们使用模型对训练数据进行预测,并计算精确度、召回率和 F1 分数。

with torch.no_grad():predictions = loaded_model(X)predicted_labels = (predictions > 0.5).float()print("实际结果:", y.numpy().flatten())
print("预测结果:", predicted_labels.numpy().flatten())precision = precision_score(y.numpy().flatten(), predicted_labels.numpy().flatten())
recall = recall_score(y.numpy().flatten(), predicted_labels.numpy().flatten())
f1 = f1_score(y.numpy().flatten(), predicted_labels.numpy().flatten())print(f"精确度(Precision): {precision:.4f}")
print(f"召回率(Recall): {recall:.4f}")
print(f"F1 分数: {f1:.4f}")

7. 运行结果

运行上述代码后,我们得到了以下结果:

  • 实际结果:[0. 0. 0. 0. 0. 0. 0. 1. 1. 1. 1. 1. 1. 1.]

  • 预测结果:[0. 0. 0. 0. 0. 0. 0. 1. 1. 1. 1. 1. 1. 1.]

  • 精确度(Precision):1.0000

  • 召回率(Recall):1.0000

  • F1 分数:1.0000

从结果可以看出,模型在训练集上表现良好,精确度、召回率和 F1 分数均为 1.0000。

8. 完整代码

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import precision_score, recall_score, f1_score"""使用pytorch框架实现逻辑回归并保存模型,然后保存模型后再加载模型进行预测,对预测结果计算精确度和召回率及F1分数"""
# 提取特征和标签
class1_points = np.array([[1.9, 1.2],[1.5, 2.1],[1.9, 0.5],[1.5, 0.9],[0.9, 1.2],[1.1, 1.7],[1.4, 1.1]])class2_points = np.array([[3.2, 3.2],[3.7, 2.9],[3.2, 2.6],[1.7, 3.3],[3.4, 2.6],[4.1, 2.3],[3.0, 2.9]])# 提取两类特征,输入特征维度为2
x1_data = np.concatenate((class1_points[:, 0], class2_points[:, 0]), axis=0)
x2_data = np.concatenate((class1_points[:, 1], class2_points[:, 1]), axis=0)
label = np.concatenate((np.zeros(len(class1_points)), np.ones(len(class2_points))), axis=0)# 将数据转换为 PyTorch 张量
X = torch.tensor(np.column_stack((x1_data, x2_data)), dtype=torch.float32)
y = torch.tensor(label, dtype=torch.float32).view(-1, 1)# 定义逻辑回归模型
class LogisticRegression(nn.Module):def __init__(self):super(LogisticRegression, self).__init__()self.linear = nn.Linear(2, 1)  # 输入特征维度为 2,输出为 1def forward(self, x):return torch.sigmoid(self.linear(x))# 初始化模型、损失函数和优化器
model = LogisticRegression()
criterion = nn.BCELoss()  # 二分类交叉熵损失
optimizer = optim.SGD(model.parameters(), lr=0.01)# 训练模型
epochs = 5000
for epoch in range(epochs):model.train()optimizer.zero_grad()outputs = model(X)loss = criterion(outputs, y)loss.backward()optimizer.step()if (epoch + 1) % 100 == 0:print(f'Epoch [{epoch + 1}/{epochs}], Loss: {loss.item():.4f}')# 保存模型
torch.save(model.state_dict(), 'model3.pth')
print("模型已保存")# 加载模型
loaded_model = LogisticRegression()
loaded_model.load_state_dict(torch.load('model3.pth',map_location=torch.device('cpu'),weights_only=True))
loaded_model.eval()# 进行预测
with torch.no_grad():predictions = loaded_model(X)predicted_labels = (predictions > 0.5).float()# 展示预测结果和实际结果
print("实际结果:", y.numpy().flatten())
print("预测结果:", predicted_labels.numpy().flatten())# 计算精确度、召回率和 F1 分数
precision = precision_score(y.numpy().flatten(), predicted_labels.numpy().flatten())
recall = recall_score(y.numpy().flatten(), predicted_labels.numpy().flatten())
f1 = f1_score(y.numpy().flatten(), predicted_labels.numpy().flatten())print(f"精确度(Precision): {precision:.4f}")
print(f"召回率(Recall): {recall:.4f}")
print(f"F1 分数: {f1:.4f}")

相关文章:

使用 PyTorch 实现逻辑回归并评估模型性能

1. 逻辑回归简介 逻辑回归是一种用于解决二分类问题的算法。它通过一个逻辑函数(Sigmoid 函数)将线性回归的输出映射到 [0, 1] 区间内,从而将问题转化为概率预测问题。如果预测概率大于 0.5,则将样本分类为正类;否则分…...

python学opencv|读取图像(五十二)使用cv.matchTemplate()函数实现最佳图像匹配

【1】引言 前序学习了图像的常规读取和基本按位操作技巧,相关文章包括且不限于: python学opencv|读取图像-CSDN博客 python学opencv|读取图像(四十九)原理探究:使用cv2.bitwise()系列函数实现图像按位运算-CSDN博客…...

【VUE案例练习】前端vue2+element-ui,后端nodo+express实现‘‘文件上传/删除‘‘功能

近期在做跟毕业设计相关的数据后台管理系统,其中的列表项展示有图片展示,添加/编辑功能有文件上传。 “文件上传/删除”也是我们平时开发会遇到的一个功能,这里分享个人的实现过程,与大家交流谈论~ 一、准备工作 本次案例使用的…...

使用真实 Elasticsearch 进行高级集成测试

作者:来自 Elastic Piotr Przybyl 掌握高级 Elasticsearch 集成测试:更快、更智能、更优化。 在上一篇关于集成测试的文章中,我们介绍了如何通过改变数据初始化策略来缩短依赖于真实 Elasticsearch 的集成测试的执行时间。在本期中&#xff0…...

【R语言】函数

一、函数格式 如下所示: hello:函数名;function:定义的R对象是函数而不是其它变量;():函数的输入参数,可以为空,也可以包含参数;{}:函数体,如果…...

Vue 3 30天精进之旅:Day 12 - 异步操作

在现代前端开发中,异步操作是一个非常常见的需求,例如从后端API获取数据、进行文件上传等任务。Vue 3 结合组合式API和Vuex可以方便地处理这些异步操作。今天我们将重点学习如何在Vue应用中进行异步操作,包括以下几个主题: 异步操…...

VSCode插件Live Server

简介:插件Live Server能够实现当我们在VSCode编辑器里修改 HTML、CSS 或者 JavaScript 文件时,它都能自动实时地刷新浏览器页面,让我们实时看到代码变化的效果。再也不用手动刷新浏览器了,节省了大量的开发过程耗时! 1…...

50. 正点原子官方系统镜像烧写实验

一、Windows下使用OTG烧写系统 1、在Windos使用NXP提供的mfgtool来向开发烧写系统。需要用先将开发板的USB_OTG接口连接到电脑上。 Mfgtool工具是向板子先下载一个Linux系统,然后通过这个系统来完成烧写工作。 切记!使用OTG烧写的时候要先把SD卡拔出来&…...

在C#中,什么是多态如何实现

在C#中,什么是多态?如何实现? C#中的多态性 多态性是面向对象编程的一个核心概念,他允许对象以多种形式表现.在C#中,多态主要通过虚方法,抽象方法和接口来实现. 多态性的存在使得同一个行为可以有多个不同的表达形式 即同一个接口可以使用不同的实例来执行不同的操作 虚方…...

搜索引擎友好:设计快速收录的网站架构

本文来自:百万收录网 原文链接:https://www.baiwanshoulu.com/14.html 为了设计一个搜索引擎友好的网站架构,以实现快速收录,可以从以下几个方面入手: 一、清晰的目录结构与层级 合理划分内容:目录结构应…...

扩散模型(三)

相关阅读: 扩散模型(一) 扩散模型(二) Latent Variable Space 潜在扩散模型(LDM;龙巴赫、布拉特曼等人,2022 年)在潜在空间而非像素空间中运行扩散过程,这…...

it基础使用--5---git远程仓库

it基础使用–5—git远程仓库 1. 按顺序看 -git基础使用–1–版本控制的基本概念 -git基础使用–2–gti的基本概念 -git基础使用–3—安装和基本使用 -git基础使用–4—git分支和使用 2. 什么是远程仓库 在第一篇文章中,我们已经讲过了远程仓库,每个本…...

Baklib如何改变内容管理平台的未来推动创新与效率提升

内容概要 在信息爆炸的时代,内容管理平台成为了企业和个人不可或缺的工具。它通过高效组织、存储和发布内容,帮助用户有效地管理信息流。随着技术的发展,传统的内容管理平台逐渐暴露出灵活性不足、易用性差等局限性,这促使市场需…...

100.1 AI量化面试题:解释夏普比率(Sharpe Ratio)的计算方法及其在投资组合管理中的应用,并说明其局限性

目录 0. 承前1. 夏普比率的基本概念1.1 定义与计算方法1.2 实际计算示例 2. 在投资组合管理中的应用2.1 投资组合选择2.2 投资组合优化 3. 夏普比率的局限性3.1 统计假设的限制3.2 实践中的问题 4. 改进方案4.1 替代指标4.2 实践建议 5. 回答话术 0. 承前 如果想更加全面清晰地…...

Ubuntu 下 nginx-1.24.0 源码分析 ngx_debug_init();

目录 ngx_debug_init() 函数: NGX_LINUX 的定义: ngx_debug_init() 函数: ngx_debug_init() 函数定义在 src\os\unix 目录下的 ngx_linux_config.h 中 #define ngx_debug_init() 也就是说这个环境下的 main 函数中的 ngx_debug_init() 这…...

基于人脸识别的课堂考勤系统

该项目是一个基于人脸识别的课堂考勤系统,使用Python开发,结合了多种技术实现考勤功能。要开发类似的基于人脸识别的考勤系统,可参考以下步骤: 环境搭建:利用Anaconda创建虚拟环境,指定Python版本为3.8&am…...

开启 AI 学习之旅:从入门到精通

最近 AI 真的超火,不管是工作还是生活里,到处都能看到它的身影。好多小伙伴都跑来问我,到底该怎么学 AI 呢?今天我就把自己学习 AI 的经验和心得分享出来,希望能帮到想踏入 AI 领域的朋友们! 一、学习内容有哪些 (一)编程语言 Python 绝对是首选!它在 AI 领域的生态…...

13 尺寸结构模块(size.rs)

一、size.rs源码 // Copyright 2013 The Servo Project Developers. See the COPYRIGHT // file at the top-level directory of this distribution. // // Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or // http://www.apache.org/licenses/LICENSE…...

16.[前端开发]Day16-HTML+CSS阶段练习(网易云音乐五)

完整代码 网易云-main-left-rank&#xff08;排行榜&#xff09; <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta http-equiv"X-UA-Compatible" content"IEedge"><meta name&q…...

ARM嵌入式学习--第十天(UART)

--UART介绍 UART(Universal Asynchonous Receiver and Transmitter)通用异步接收器&#xff0c;是一种通用串行数据总线&#xff0c;用于异步通信。该总线双向通信&#xff0c;可以实现全双工传输和接收。在嵌入式设计中&#xff0c;UART用来与PC进行通信&#xff0c;包括与监控…...

MoonBit 编译器(留档学习)

MoonBit 编译器 MoonBit 是一个用户友好&#xff0c;构建快&#xff0c;产出质量高的编程语言。 MoonBit | Documentation | Tour | Core This is the source code repository for MoonBit, a programming language that is user-friendly, builds fast, and produces high q…...

【TypeScript】基础:数据类型

文章目录 TypeScript一、简介二、类型声明三、数据类型anyunknownnervervoidobjecttupleenumType一些特殊情况 TypeScript 是JavaScript的超集&#xff0c;代码量比JavaScript复杂、繁多&#xff1b;但是结构更清晰 一、简介 为什么需要TypeScript&#xff1f; JavaScript的…...

Unity游戏(Assault空对地打击)开发(3) 摄像机的控制

详细步骤 打开My Assets或者Package Manager。 选择Unity Registry。 搜索Cinemachine&#xff0c;找到 Cinemachine包&#xff0c;点击 Install按钮进行安装。 关闭窗口&#xff0c;新建一个FreeLook Camera&#xff0c;如下。 接着新建一个对象Pos&#xff0c;拖到Player下面…...

【HarmonyOS之旅】基于ArkTS开发(三) -> 兼容JS的类Web开发(三)

目录 1 -> 生命周期 1.1 -> 应用生命周期 1.2 -> 页面生命周期 2 -> 资源限定与访问 2.1 -> 资源限定词 2.2 -> 资源限定词的命名要求 2.3 -> 限定词与设备状态的匹配规则 2.4 -> 引用JS模块内resources资源 3 -> 多语言支持 3.1 -> 定…...

小程序-基础加强-自定义组件

前言 这次讲自定义组件 1. 准备今天要用到的项目 2. 初步创建并使用自定义组件 这样就成功在home中引入了test组件 在json中引用了这个组件才能用这个组件 现在我们来实现全局引用组件 在app.json这样使用就可以了 3. 自定义组件的样式 发现页面里面的文本和组件里面的文…...

尝试ai生成figma设计

当听到用ai 自动生成figma设计时&#xff0c;不免好奇这个是如何实现的。在查阅了不少资料后&#xff0c;有了一些想法。参考了&#xff1a;在figma上使用脚本自动生成色谱 这篇文章提供的主要思路是&#xff1a;可以通过脚本的方式构建figma设计。如果我们使用ai 生成figma脚本…...

【周易哲学】生辰八字入门讲解(八)

&#x1f60a;你好&#xff0c;我是小航&#xff0c;一个正在变秃、变强的文艺倾年。 &#x1f514;本文讲解【周易哲学】生辰八字入门讲解&#xff0c;期待与你一同探索、学习、进步&#xff0c;一起卷起来叭&#xff01; 目录 一、六亲女命六亲星六亲宫位相互关系 男命六亲星…...

康德哲学与自组织思想的渊源:从《判断力批判》到系统论的桥梁

康德哲学与自组织思想的渊源&#xff1a;从《判断力批判》到系统论的桥梁 第一节&#xff1a;康德哲学中的自然目的论与自组织思想 核心内容&#xff1a; 康德哲学中的自然目的论和反思判断力概念&#xff0c;为现代系统论中的自组织思想提供了哲学基础&#xff0c;预见了复…...

解决whisper 本地运行时GPU 利用率不高的问题

我在windows 环境下本地运行whisper 模型&#xff0c;使用的是nivdia RTX4070 显卡&#xff0c;结果发现GPU 的利用率只有2% 。使用 import torch print(torch.cuda.is_available()) 返回TRUE。表示我的cuda 是可用的。 最后在github 的下列网页上找到了问题 极低的 GPU 利…...

【自开发工具介绍】SQLSERVER的ImpDp和ExpDp工具02

工具运行前的环境准备 1、登录用户管理员权限确认 工具使用的登录用户(-u后面的用户)&#xff0c;必须具有管理员的权限&#xff0c;因为需要读取系统表 例&#xff1a;Export.bat -s 10.48.111.12 -d db1 -u test -p test -schema dbo      2、Powershell的安全策略确认…...