当前位置: 首页 > news >正文

机器学习 - 提高模型 (代码)

如果模型出现了 underfitting 问题,就得提高模型了。

Model improvement techniqueWhat does it do?
Add more layersEach layer potentially increases the learning capabilities of the model with each layer being able to learn some kind of new pattern in the data, more layers is often referred to as making your neural network deeper.
Add more hidden unitsMore hidden units per layer means a potential increase in learning capabilities of the model, more hidden units is often referred to as making your neural network wider.
Fitting for longer (more epochs)Your model might learn more if it had more opportunities to look at the data.
Changing the activation functionsSome data just can’t be fit with only straight lines, using non-linear activation functions can help with this.
Change the learning rateLess model specific, but still related, the learning rate of the optimizer decides how much a model should change its parameter each step, too much and the model overcorrects, too little and it doesn’t learn enough.
Change the loss functionLess model specific but still important, different problems require different loss functions. For example, a binary cross entropy loss function won’t work with a multi-class classification problem.
Use transfer learningTake a pretrained model from a problem domain similar to yours and adjust it to your own problem.

举个例子,代码如下:

class CircleModelV1(nn.Module):def __init__(self):super().__init__()self.layer_1 = nn.Linear(in_features = 2, out_features = 10)self.layer_2 = nn.Linear(in_features = 10, out_features = 10)self.layer_3 = nn.Linear(in_features = 10, out_features = 1)def forward(self, x):return self.layer_3(self.layer_2(self.layer_1(x)))model_1 = CircleModelV1().to("cpu")
print(model_1)loss_fn = nn.BCEWithLogitsLoss()
optimizer = torch.optim.SGD(model_1.parameters(), lr=0.1)torch.manual_seed(42)epochs = 1000X_train, y_train = X_train.to("cpu"), y_train.to("cpu")
X_test, y_test = X_test.to("cpu"), y_test.to("cpu")for epoch in range(epochs):### Training# 1. Forward pass y_logits = model_1(X_train).squeeze()y_pred = torch.round(torch.sigmoid(y_logits))  # logits -> probabilities -> prediction labels # 2. Calculate loss/accuracy loss = loss_fn(y_logits, y_train)acc = accuracy_fn(y_true = y_train, y_pred = y_pred)# 3. Optimizer zero grad optimizer.zero_grad()# 4. Loss backwards loss.backward()# 5. Optimizer step optimizer.step() ### Testing model_1.eval()with torch.inference_mode():# 1. Forward pass test_logits = model_1(X_test).squeeze()test_pred = torch.round(torch.sigmoid(test_logits))# 2. Calculate loss/accuracy test_loss = loss_fn(test_logits, y_test)test_acc = accuracy_fn(y_true = y_test, y_pred = test_pred)if epoch % 100 == 0:print(f"Epoch: {epoch} | Loss: {loss:.5f}, Accuracy: {acc:.2f}%")# 结果如下
CircleModelV1((layer_1): Linear(in_features=2, out_features=10, bias=True)(layer_2): Linear(in_features=10, out_features=10, bias=True)(layer_3): Linear(in_features=10, out_features=1, bias=True)
)
Epoch: 0 | Loss: 0.69528, Accuracy: 51.38%
Epoch: 100 | Loss: 0.69325, Accuracy: 47.88%
Epoch: 200 | Loss: 0.69309, Accuracy: 49.88%
Epoch: 300 | Loss: 0.69303, Accuracy: 50.50%
Epoch: 400 | Loss: 0.69300, Accuracy: 51.38%
Epoch: 500 | Loss: 0.69299, Accuracy: 51.12%
Epoch: 600 | Loss: 0.69298, Accuracy: 51.50%
Epoch: 700 | Loss: 0.69298, Accuracy: 51.38%
Epoch: 800 | Loss: 0.69298, Accuracy: 51.50%
Epoch: 900 | Loss: 0.69298, Accuracy: 51.38%

都看到这了,点个赞呗~

相关文章:

机器学习 - 提高模型 (代码)

如果模型出现了 underfitting 问题,就得提高模型了。 Model improvement techniqueWhat does it do?Add more layersEach layer potentially increases the learning capabilities of the model with each layer being able to learn some kind of new pattern in…...

数值代数及方程数值解:预备知识——二进制及浮点数

文章目录 二进制IEEE浮点数 本篇文章的前置知识:数学分析 二进制 命题:二进制转化为十进制 二进制的数字表示为 ⋯ b 2 b 1 b 0 . b − 1 b − 2 ⋯ \cdots b_2b_1b_0.b_{-1}b_{-2}\cdots ⋯b2​b1​b0​.b−1​b−2​⋯这等价于十进制下的 ⋯ b 2 2 …...

新数字时代的启示:揭开Web3的秘密之路

在当今数字时代,随着区块链技术的不断发展,Web3作为下一代互联网的概念正逐渐引起人们的关注和探索。本文将深入探讨新数字时代的启示,揭开Web3的神秘之路,并探讨其在未来的发展前景。 1. Web3的定义与特点 Web3是对互联网未来发…...

算法——动态规划:01背包

原始01背包见下面这篇文章:http://t.csdnimg.cn/a1kCL 01背包的变种:. - 力扣(LeetCode) 给你一个 只包含正整数 的 非空 数组 nums 。请你判断是否可以将这个数组分割成两个子集,使得两个子集的元素和相等。 简化一…...

写作类AI推荐(二)

本章要介绍的写作AI如下: 火山写作 主要功能: AI智能创作:告诉 AI 你想写什么,立即生成你理想中的文章AI智能改写:选中段落句子,可提升表达、修改语气、扩写、总结、缩写等文章内容优化:根据全文…...

分寝室(20分)(JAVA)

目录 题目描述 输入格式: 输出格式: 输入样例 1: 输出样例 1: 输入样例 2: 输出样例 2: 题解: 题目描述 学校新建了宿舍楼,共有 n 间寝室。等待分配的学生中,有女…...

Spring 源码调试问题 ( List.of(“bin“, “build“, “out“); )

Spring 源码调试问题 文章目录 Spring 源码调试问题一、问题描述二、解决方案 一、问题描述 错误&#xff1a;springframework\buildSrc\src\main\java\org\springframework\build\CheckstyleConventions.java:68: 错误: 找不到符号 List<String> buildFolders List.of…...

Centos7安装RTL8111网卡驱动

方法一&#xff1a; // 安装pciutils # yum install -y pciutils // 查看pci设备信息 # lspci | grep -i Ethernet 09:00.0 Ethernet controller: Realtek Semiconductor Co., Ltd. RTL8111/8168/8411 PCI Express Gigabit Ethernet Controller (rev 03) // 上面看到是Re…...

吉时利KEITHLEY2460数字源表

181/2461/8938产品概述&#xff1a; Keithley 2460 高电流源表源测量单元 (SMU) 将先进的触摸、测试和发明技术带到您的指尖。Keithley 2460 将创新的图形用户界面 (GUI) 与电容式触摸屏技术相结合&#xff0c;使测试变得直观并最大限度地缩短学习曲线&#xff0c;从而帮助工程…...

数据库原理(含思维导图)

数据库原理笔记&#xff0c;html与md笔记已上传 1.绪论 发展历程 记住数据怎么保存&#xff0c;谁保存数据&#xff0c;共享性如何&#xff0c;独立性如何 人工管理阶段 数据不保存应用程序管理数据数据不共享数据不具有独立性 文件系统阶段 数据可以长期保存文件系统管…...

数据结构(六)——图

六、图 6.1 图的基本概念 图的定义 图&#xff1a;图G由顶点集V和边集E组成&#xff0c;记为G (V, E)&#xff0c;其中V(G)表示图G中顶点的有限非空集&#xff1b;E(G) 表示图G中顶点之间的关系&#xff08;边&#xff09;集合。若V {v1, v2, … , vn}&#xff0c;则用|V|…...

Android-AR眼镜屏幕显示

Android-AR眼镜 前提&#xff1a;Android手持设备 需要具备DP高清口 1、创建Presentation&#xff08;双屏异显&#xff09; public class MyPresentation extends Presentation {private PreviewSingleBinding binding;private ScanActivity activity;public MyPresentatio…...

蓝桥集训之货币系统

蓝桥集训之货币系统 核心思想&#xff1a;背包 #include <iostream>#include <cstring>#include <algorithm>using namespace std;const int N 30,M 10010;typedef long long LL;LL f[M];int w[N];int n,m;int main(){cin>>n>>m;for(int i1;i&…...

基于微信小程序的校园服务平台设计与实现(程序+论文)

本文以校园服务平台为研究对象&#xff0c;首先分析了当前校园服务平台的研究现状&#xff0c;阐述了本系统设计的意义和背景&#xff0c;运用微信小程序开发工具和云开发技术&#xff0c;研究和设计了一个校园服务平台&#xff0c;以满足学生在校园生活中的多样化需求。通过引…...

QT+Opencv+yolov5实现监测

功能说明&#xff1a;使用QTOpencvyolov5实现监测 仓库链接&#xff1a;https://gitee.com/wangyoujie11/qt_yolov5.git git本仓库到本地 一、环境配置 1.opencv配置 将OpenCV-MinGW-Build-OpenCV-4.5.2-x64文件夹放在自己的一个目录下&#xff0c;如我的路径&#xff1a; …...

【Python-Docx库】Word与Python的完美结合

【Python-Docx库】Word与Python的完美结合 今天给大家分享Python处理Word的第三方库&#xff1a;Python-Docx。 什么是Python-Docx&#xff1f; Python-Docx是用于创建和更新Microsoft Word&#xff08;.docx&#xff09;文件的Python库。 日常需要经常处理Word文档&#xf…...

吴恩达深度学习笔记:浅层神经网络(Shallow neural networks)3.6-3.8

目录 第一门课&#xff1a;神经网络和深度学习 (Neural Networks and Deep Learning)第三周&#xff1a;浅层神经网络(Shallow neural networks)3.6 激活函数&#xff08;Activation functions&#xff09;3.7 为什么需要非线性激活函数&#xff1f;&#xff08;why need a non…...

盘点最适合做剧场版的国漫,最后一部有望成为巅峰

最近《完美世界》动画官宣首部剧场版&#xff0c;主要讲述石昊和火灵儿的故事。这个消息一出&#xff0c;引发了很多漫迷的讨论&#xff0c;其实现在已经有好几部国漫做过剧场版&#xff0c;还有是观众一致希望未来会出剧场版的。那么究竟是哪些国漫呢&#xff0c;下面就一起来…...

Altium Designer许可需求分析

在电子设计的世界中&#xff0c;Altium Designer已成为设计师们的得力助手。然而&#xff0c;如何进行有效的许可需求分析&#xff0c;以确保软件的高效使用和企业的可持续发展&#xff1f;本文将带您了解如何进行Altium Designer的许可需求分析&#xff0c;让您在设计的道路上…...

[c++]类和对象常见题目详解

本专栏内容为&#xff1a;C学习专栏&#xff0c;分为初阶和进阶两部分。 通过本专栏的深入学习&#xff0c;你可以了解并掌握C。 &#x1f493;博主csdn个人主页&#xff1a;小小unicorn ⏩专栏分类&#xff1a;C &#x1f69a;代码仓库&#xff1a;小小unicorn的代码仓库&…...

v7上线首周,93%老用户没发现的隐藏指令——高阶提示工程实战手册,含12个未公开参数调用语法

更多请点击&#xff1a; https://intelliparadigm.com 第一章&#xff1a;Midjourney v7核心架构升级与隐性能力图谱 多模态融合推理引擎重构 Midjourney v7 引入了基于分层注意力对齐&#xff08;Hierarchical Attention Alignment, HAA&#xff09;的新型生成主干&#xff…...

AI应用开发平台RiserFlow实战:从架构解析到智能客服构建

1. 项目概述&#xff1a;从“RiserFlow”看现代AI应用开发范式的演进最近在GitHub上看到一个挺有意思的项目&#xff0c;叫riserlabs/riserflow。光看这个名字&#xff0c;可能有点摸不着头脑&#xff0c;但如果你点进去&#xff0c;会发现它其实指向一个更具体的产品&#xff…...

BetterGI自动化工具:每天为原神玩家节省2小时

BetterGI自动化工具&#xff1a;每天为原神玩家节省2小时 【免费下载链接】better-genshin-impact &#x1f4e6;BetterGI 更好的原神 - 自动拾取 | 自动剧情 | 全自动钓鱼(AI) | 全自动七圣召唤 | 自动伐木 | 自动刷本 | 自动采集/挖矿/锄地 | 一条龙 | 全连音游 | 自动烹饪 …...

在Node.js后端服务中集成Taotoken调用大模型指南

&#x1f680; 告别海外账号与网络限制&#xff01;稳定直连全球优质大模型&#xff0c;限时半价接入中。 &#x1f449; 点击领取海量免费额度 在Node.js后端服务中集成Taotoken调用大模型指南 将大模型能力集成到后端服务是现代应用开发的常见需求。Taotoken平台提供了OpenA…...

ClawGuard:为Clawdbot AI智能体打造的安全监控与熔断防护系统

1. 项目概述&#xff1a;ClawGuard 是什么&#xff0c;以及为什么你需要它如果你正在使用或开发基于 Clawdbot 框架的 AI 智能体&#xff0c;那么“安全”和“可控”这两个词&#xff0c;大概率已经在你脑海里盘旋过无数次了。我接触过不少团队&#xff0c;从最初的兴奋于 AI 智…...

AI时代Clean Code新标准(DeepSeek R1实测验证版):92.7%可维护性提升背后的11个关键断点

更多请点击&#xff1a; https://intelliparadigm.com 第一章&#xff1a;AI时代Clean Code范式迁移的必然性 当大语言模型能自动生成函数、修复漏洞、甚至重构整包逻辑时&#xff0c;“可读性优先”的传统Clean Code原则正遭遇结构性挑战。人类开发者编写的代码不再唯一面向…...

星际软件开发:为火星殖民地编写第一批代码

一、引言&#xff1a;当测试左移到大气层之外2041年&#xff0c;第一批火星殖民者即将启程。他们携带的不仅是氧气和速食&#xff0c;还有一座预装在密封舱里的微型数据中心。在这片红色荒漠上&#xff0c;代码将比氧气更早醒来——生命维持系统的控制逻辑、通讯中继的协议栈、…...

AI科技热点日报 | 2026年5月12日

文章目录AI科技热点日报 | 2026年5月12日一、 行业标准与规范&#xff1a;AI终端迈入“标准化”时代二、 智能体&#xff08;Agent&#xff09;与具身智能&#xff1a;从云端走向实战三、 算力与基础设施&#xff1a;产业链的深度重构四、 产业融合与应用探索&#xff1a;AI fo…...

3步精通UE4SS游戏Mod开发:从注入到实战完全指南

3步精通UE4SS游戏Mod开发&#xff1a;从注入到实战完全指南 【免费下载链接】RE-UE4SS Injectable LUA scripting system, SDK generator, live property editor and other dumping utilities for UE4/5 games 项目地址: https://gitcode.com/gh_mirrors/re/RE-UE4SS UE…...

航拍UAV电力电缆巡检检测数据集_数据集第10027期

航拍UAV电力电缆巡检检测数据集_数据集第10027期 项目简介 面向无人机电力巡检场景的开源目标检测数据集&#xff0c;聚焦电力电缆识别任务&#xff0c;可用于电力线检测、植被与电力线安全距离监测等场景&#xff0c;助力电力巡检智能化。 数据集核心信息 数据规模&#xff1a…...