当前位置: 首页 > article >正文

PyTorch实现逻辑回归:从原理到实战

1. 逻辑回归基础与PyTorch实现概览逻辑回归是机器学习中最基础但极其重要的分类算法尽管名字中带有回归它实际上解决的是二分类问题。在PyTorch框架下实现逻辑回归不仅能理解深度学习的基础构建块还能掌握自定义模型的核心方法。关键理解逻辑回归本质是在线性回归的输出上套用sigmoid函数将任意实数映射到(0,1)区间解释为概率值。当概率0.5时预测为正类否则为负类。1.1 为什么选择PyTorch实现PyTorch的动态计算图特性使得模型开发和调试过程非常直观即时执行模式操作结果立即可见便于理解数据流动自动微分系统无需手动实现反向传播模块化设计nn.Module基类提供标准的模型封装方式GPU加速只需简单.to(device)即可迁移计算设备import torch import torch.nn as nn # 基础检查验证环境配置 print(fPyTorch版本: {torch.__version__}) print(fCUDA可用: {torch.cuda.is_available()})2. 核心组件实现详解2.1 Sigmoid函数原理与可视化Sigmoid函数的数学表达式为 $$ \sigma(z) \frac{1}{1e^{-z}} $$其特性包括将输入压缩到(0,1)区间在z0处斜率最大两端梯度趋于平缓可能引发梯度消失import matplotlib.pyplot as plt def plot_sigmoid(): z torch.arange(-10, 10, 0.1) sigmoid nn.Sigmoid() plt.figure(figsize(10, 5)) plt.plot(z.numpy(), sigmoid(z).numpy(), labelSigmoid) plt.axvline(0, colorr, linestyle--, alpha0.3) plt.axhline(0.5, colorr, linestyle--, alpha0.3) plt.xlabel(Input value (z)) plt.ylabel(Sigmoid output) plt.title(Sigmoid Function Curve) plt.grid(True) plt.legend() plt.show() plot_sigmoid()2.2 两种模型构建方式对比方案A使用nn.Sequential快速搭建sequential_model nn.Sequential( nn.Linear(in_features1, out_features1), nn.Sigmoid() )优势代码简洁适合简单模型层间自动传递数据参数自动初始化方案B自定义nn.Module子类class LogisticRegression(nn.Module): def __init__(self, input_dim): super().__init__() self.linear nn.Linear(input_dim, 1) def forward(self, x): return torch.sigmoid(self.linear(x))优势灵活控制前向传播逻辑可添加自定义方法便于复杂模型扩展实际选择建议对于生产环境推荐使用自定义类方式虽然代码量稍多但更易维护和扩展。3. 完整训练流程实现3.1 数据准备与加载构建一个模拟的二分类数据集def generate_data(n_samples1000): torch.manual_seed(42) X torch.randn(n_samples, 2) * 1.5 # 创建分类边界线性可分 y ((X[:, 0] X[:, 1]) 0).float() return X, y.unsqueeze(1) X, y generate_data() print(f特征形状: {X.shape}, 标签形状: {y.shape}) # 数据集可视化 plt.scatter(X[:,0], X[:,1], cy.squeeze(), cmapbwr, alpha0.6) plt.xlabel(Feature 1) plt.ylabel(Feature 2) plt.title(Generated Classification Data) plt.colorbar() plt.show()3.2 训练循环实现def train_model(model, X, y, epochs1000, lr0.01): criterion nn.BCELoss() # 二分类交叉熵损失 optimizer torch.optim.SGD(model.parameters(), lrlr) losses [] for epoch in range(epochs): # 前向传播 outputs model(X) loss criterion(outputs, y) # 反向传播 optimizer.zero_grad() loss.backward() optimizer.step() losses.append(loss.item()) if (epoch1) % 100 0: print(fEpoch [{epoch1}/{epochs}], Loss: {loss.item():.4f}) return losses # 实例化模型 model LogisticRegression(input_dim2) loss_history train_model(model, X, y) # 绘制损失曲线 plt.plot(loss_history) plt.xlabel(Epoch) plt.ylabel(Loss) plt.title(Training Loss Curve) plt.grid(True) plt.show()3.3 模型评估与决策边界def plot_decision_boundary(model, X, y): # 创建网格点 x_min, x_max X[:,0].min()-1, X[:,0].max()1 y_min, y_max X[:,1].min()-1, X[:,1].max()1 xx, yy torch.meshgrid(torch.linspace(x_min, x_max, 100), torch.linspace(y_min, y_max, 100)) # 预测网格点类别 with torch.no_grad(): Z model(torch.cat([xx.reshape(-1,1), yy.reshape(-1,1)], dim1)) Z Z.reshape(xx.shape) 0.5 # 绘制结果 plt.contourf(xx.numpy(), yy.numpy(), Z.numpy(), alpha0.3, cmapbwr) plt.scatter(X[:,0], X[:,1], cy.squeeze(), cmapbwr, edgecolorsk) plt.title(Decision Boundary) plt.xlabel(Feature 1) plt.ylabel(Feature 2) plt.show() plot_decision_boundary(model, X, y)4. 实战技巧与问题排查4.1 常见问题解决方案问题现象可能原因解决方案损失不下降学习率设置不当尝试0.1, 0.01, 0.001等不同学习率预测全为0或1数据不平衡使用class_weight或重采样梯度爆炸输入值范围过大标准化输入特征准确率波动大批量大小太小增大batch_size或使用全批量4.2 性能优化技巧数据预处理标准化from sklearn.preprocessing import StandardScaler scaler StandardScaler() X_scaled torch.FloatTensor(scaler.fit_transform(X))学习率调度scheduler torch.optim.lr_scheduler.StepLR(optimizer, step_size30, gamma0.1) # 在训练循环中添加 scheduler.step()早停机制best_loss float(inf) patience 10 counter 0 for epoch in range(epochs): # ...训练代码... if loss.item() best_loss: best_loss loss.item() counter 0 else: counter 1 if counter patience: print(Early stopping triggered) break4.3 模型保存与加载# 保存整个模型 torch.save(model, logistic_regression.pth) # 仅保存参数推荐 torch.save(model.state_dict(), lr_params.pth) # 加载模型 loaded_model LogisticRegression(input_dim2) loaded_model.load_state_dict(torch.load(lr_params.pth)) loaded_model.eval() # 设置为评估模式5. 进阶扩展方向5.1 多分类逻辑回归通过修改输出层实现多分类class MulticlassLogisticRegression(nn.Module): def __init__(self, input_dim, num_classes): super().__init__() self.linear nn.Linear(input_dim, num_classes) def forward(self, x): return torch.softmax(self.linear(x), dim1)5.2 正则化应用L2正则化权重衰减optimizer torch.optim.SGD(model.parameters(), lr0.01, weight_decay0.1)5.3 GPU加速实现device torch.device(cuda if torch.cuda.is_available() else cpu) model LogisticRegression(input_dim2).to(device) X, y X.to(device), y.to(device)在实际项目中逻辑回归往往作为基线模型出现。虽然结构简单但理解其PyTorch实现能帮助我们掌握深度学习模型的核心构建模式。当遇到更复杂模型时这些基础技术会派上大用场。

相关文章:

PyTorch实现逻辑回归:从原理到实战

1. 逻辑回归基础与PyTorch实现概览逻辑回归是机器学习中最基础但极其重要的分类算法,尽管名字中带有"回归",它实际上解决的是二分类问题。在PyTorch框架下实现逻辑回归,不仅能理解深度学习的基础构建块,还能掌握自定义模…...

RAGFlow · 第 3 章:第一节 RAGFlow 配置参数全景图与实验结论

系列导航 第 0 章 前言:为什么企业 AI 工程师必须掌握 RAGFlow第 1 章:安装部署与基础配置**——从零跑通第一个 RAG Pipeline第 2 章:RAGFlow RAGFlow 代码介绍第 3 章:攻克企业复杂文档——理解 DeepDoc、Naive、MinerU 与 Docl…...

NVIDIA Nemotron 3架构解析:智能体AI与混合Mamba-Transformer MoE设计

1. NVIDIA Nemotron 3架构解析:面向智能体AI的新一代模型设计在当今AI领域,智能体系统(Agentic AI)正变得越来越复杂。这类系统通常由多个协作的智能体组成——包括检索器、规划器、工具执行器和验证器等——它们需要在大量上下文…...

AI 时代最大的谎言:你以为在学习,其实在欠债—思维决定上限的反焦虑框架

文章目录1、写在前面:我为什么不再写"AI 焦虑"2、本文速览3、AI 焦虑的真实闭环:你不是在错过 AI3.1、焦虑的来源不是机会,是怕3.2、机会从来不属于"绝大多数人"3.3、对你的实际意义4、MIT 认知负债:所有 AI …...

每日一学:设计模式之观察者模式

观察者模式(Observer Pattern)属于行为型设计模式,核心定义:构建对象间一对多的依赖关系,当被观察者(发布者 / 主题)状态发生变化时,所有订阅它的观察者(订阅者&#xff…...

【2026年网易雷火春招- 4月26日-第一题- 喵居】(题目+思路+JavaC++Python解析+在线测试)

题目内容 在《忘川风华录》的喵居中,为了帮助名士猫完成进化,使君需要炼化出高阶的九世灵。 喵居的供台上目前散落着 nnn 团微小的「猫灵元魂」,第 iii 团元魂的灵力值为 aia_i...

Bluetooth Classic中的速率区别

0 Preface/Foreword1PHY介绍1.1 与BLE的区别BLE有PHY 1M和2M的区别,但是在Bluetooth Classic中,没有这个概念。因为PHY 1M和2M是BLE的专有术语。虽然BLE和Bluetooth Classic都是使用2.4GHz,但是走的两套不同的技术路线。1.2 PHY速率分类Bluet…...

智能电话录音总结,工具高精准识别快速整理,复盘通话超省心省事

最近试了2026年新迭代的这批智能电话录音总结工具,高精准识别加快速整理是真的香,现在复盘通话完全不用再熬大夜来回拖进度条扒内容,省心到我恨不得早两年用上。我做To B销售快三年,之前最头疼的就是每天打七八通客户电话&#xf…...

高效编程实践:用Codex告别重复造轮子

技术文章大纲:告别重复造轮子——Codex写脚本的高效实践核心概念与背景重复造轮子的定义:开发中重复实现已有功能的现象及其效率问题Codex的定位:AI辅助编程工具如何通过自然语言生成代码适用场景:快速原型开发、自动化脚本、代码…...

ChatGPT-CLI:终端集成AI助手,提升开发者效率的实战指南

1. 项目概述:一个让ChatGPT在终端里“安家”的命令行工具如果你和我一样,每天大部分时间都泡在终端(Terminal)里,那么你一定有过这样的体验:为了向ChatGPT提个问题,或者让它帮忙写段代码&#x…...

如何搭建逻辑备库_SQL Apply与不支持的数据类型评估

SQL Apply 启动失败主因是备库控制文件残留主库“只读”标记或角色未正确设为PHYSICAL STANDBY;需确保V$DATABASE中DATABASE_ROLEPHYSICAL STANDBY且OPEN_MODEMOUNTED,并清理V$DATAGUARD_CONFIG中重复DB_UNIQUE_NAME。SQL Apply 启动失败报 ORA-16000 或…...

华为HDC大会2024张平安总keynote盘古多模态生成大模型:STCG技术如何重塑自动驾驶数据引擎

从"娱乐生成"到"产业生成":盘古的差异化路径 当业界多模态大模型还在追逐一镜到底的娱乐视频生成时,盘古5.0选择了一条截然不同的技术路线——聚焦行业急需的价值场景。在华为HDC大会上,盘古团队首次系统披露了多模态生…...

GEEKOM GT1 Mega迷你主机Ubuntu 24.10性能评测

1. GEEKOM GT1 Mega迷你主机深度评测:Ubuntu 24.10下的Intel Core Ultra 9 185H体验 作为一名长期关注迷你主机的技术爱好者,最近我有机会对搭载Intel Core Ultra 9 185H处理器的GEEKOM GT1 Mega进行了全面测试。这款迷你主机在Windows 11 Pro环境下表现…...

Transformer和LLM前沿内容(4):Long-Context LLM

文章目录1. Context Extension1.1 Rotary Position Embedding (RoPE)1.2 LongLoRA2. Evaluation of Long-Context LLMs2.1 The Lost in the Middle Phenomenon2.2 Long-Context Benchmarks: NIAH, LongBench3. Efficient Attention Mechanisms3.1 KV Cache3.2 StreamingLLM and…...

YLB3118 × DeepSeek V4@ACP#国产存储控制芯片,筑牢大模型推理的 “数据基石”

在国产 AI 大模型加速落地的浪潮中,DeepSeek V4 凭借万亿级参数、百万级上下文窗口的硬核实力,成为开源大模型的标杆;而YLB3118 作为国产 PCIe 转 SATA 存储控制芯片的核心代表,以高密度扩展、低功耗、工业级可靠的特性&#xff0…...

VMware+RockyLinux10

VMwareRocky Linux 10 1、官网下载 2、安装 3、配置VMware部分 下载 VMware官方网站:https://www.vmware.com 目前只做宣传,无下载入口 可以下载到的官网:https://support.broadcom.com/group/ecx/free-downloads 右上角Login用Broadcom Supp…...

PE-bear深度解析:跨平台PE文件分析的瑞士军刀

PE-bear深度解析:跨平台PE文件分析的瑞士军刀 【免费下载链接】pe-bear Portable Executable reversing tool with a friendly GUI 项目地址: https://gitcode.com/gh_mirrors/pe/pe-bear 在逆向工程和恶意软件分析领域,PE文件分析工具是安全研究…...

齐纳二极管稳压原理与工程应用全解析

1. 齐纳二极管稳压原理深度解析 齐纳二极管(Zener Diode)作为电子电路中最经典的电压基准元件,其核心工作原理建立在PN结的反向击穿特性上。当反向电压达到特定阈值(VZ)时,二极管进入击穿区,此时…...

MusicPlayer2完全指南:10个技巧让你的Windows音乐体验焕然一新

MusicPlayer2完全指南:10个技巧让你的Windows音乐体验焕然一新 【免费下载链接】MusicPlayer2 MusicPlayer2是一款功能强大的本地音乐播放软件,旨在为用户提供最佳的本地音乐播放体验。它支持歌词显示、歌词卡拉OK样式显示、歌词在线下载、歌词编辑、歌曲…...

SVM与拉格朗日乘子法:从原理到Python实现

1. 从理论到实践:理解SVM与拉格朗日乘子法的本质支持向量机(SVM)作为机器学习领域的经典算法,其核心思想来源于统计学习理论和凸优化方法。我在实际项目中多次使用SVM解决分类问题,发现真正理解其背后的数学原理&#…...

Mysql的源码编译

1.下载安装包wget https://downloads.mysql.com/archives/get/p/23/file/mysql-boost-8.3.0.tar.gz2.源码编译​ [rootmysql-node1 ~]# dnf install cmake3 gcc git bison openssl-devel ncurses-devel systemd-devel rpcgen.x86_64 libtirpc-devel-1.3.3-9.el9.x86_64.rpm gc…...

5个小众机器学习可视化工具提升模型解释力

1. 机器学习可视化工具的隐藏瑰宝在数据科学项目中,可视化从来不只是锦上添花——它直接决定了你的模型能否被非技术背景的决策者理解。虽然Matplotlib和Seaborn已经人尽皆知,但今天我要分享的这五个小众可视化库,能让你的机器学习故事讲述能…...

谷歌SEO如何做图标优化?

在谷歌搜索算法持续演进与用户体验标准不断提升的当下,网站技术SEO的精细化程度已成为影响排名与流量的关键因素。其中,图标(Icons)作为用户界面与品牌视觉识别的重要元素,其优化处理往往被忽视,却对网站性…...

利用Obsidian Local REST API构建可检索的AI对话知识库

1. 项目概述:在 Obsidian 中构建你的 AI 对话知识库如果你和我一样,日常重度依赖 Cursor 的 AI 编程助手来探讨技术方案、解决代码问题,那么一个痛点很快就会浮现:那些充满洞见的对话,在 Cursor 的聊天历史里翻找起来异…...

从‘酷女孩’到‘商务女性’:用Stable Diffusion + Lora 玩转AI人像风格化的实战心得

从‘酷女孩’到‘商务女性’:Stable Diffusion Lora 风格化人像生成实战指南 在数字艺术创作领域,AI生成技术正以前所未有的速度重塑着内容生产方式。作为一名长期深耕AI视觉创作的实践者,我深刻体会到Stable Diffusion配合Lora模型带来的创…...

MacBook Pro用户必看:M4芯片的38 TOPS Neural Engine,真能让Stable Diffusion本地跑得更快吗?

M4芯片加持下的MacBook Pro:Stable Diffusion本地运行实战指南 当苹果在春季发布会上骄傲地宣布M4芯片的Neural Engine达到38 TOPS算力时,整个创意社区都在问同一个问题:这能让我的MacBook真正流畅运行Stable Diffusion吗?作为每天…...

机器学习工程师职业指南:从入门到高薪就业

1. 为什么现在进入机器学习领域正当时? 十年前我第一次接触机器学习时,整个领域还停留在学术论文和实验室阶段。如今超市的智能结算系统、手机里的人脸解锁、邮箱里的垃圾邮件过滤,背后都是机器学习在发挥作用。这个转变不仅意味着技术成熟度…...

概率分布实战指南:从基础到应用

1. 概率分布入门指南概率分布就像天气预报中的降水概率图——它能告诉我们不同结果出现的可能性大小。作为数据分析、机器学习和统计建模的基础工具,理解概率分布相当于掌握了量化不确定性的语言。我在金融风控和AB测试领域工作十年,每天都要和各种分布打…...

AWS CDK构造库实战:快速构建生成式AI应用基础设施

1. 项目概述:当CDK遇上生成式AI 如果你正在用AWS构建生成式AI应用,并且已经厌倦了在控制台里手动点击、配置各种服务,或者在CloudFormation模板里反复调试那些复杂的IAM权限和网络配置,那么 awslabs/generative-ai-cdk-construc…...

开源规则引擎Ruler:解耦复杂业务逻辑的声明式编程实践

1. 项目概述与核心价值最近在折腾一些文档处理和自动化流程,发现一个挺有意思的开源项目,叫intellectronica/ruler。乍一看名字,你可能会联想到“尺子”或者“规则”,没错,它的核心功能就是帮你定义和执行一系列规则&a…...