当前位置: 首页 > news >正文

李宏毅机器学习HW1: COVID-19 Cases Prediction

Kaggle数据集和提交链接

特征选择(主要修改地方)

在sample code的基础上主要修改了Select_feat选择特征函数。
首先,因为数据集中的第一列是id,先在raw_x_trainraw_x_validraw_x_test中都去掉这一列。其次,使用SelectKBest根据特征与目标之间的相关性来选择10个最重要的特征。

def select_feat(train_data, valid_data, test_data, select_all = True):# labely_train = train_data[:, -1]y_valid = valid_data[:, -1]# feature# 第一列是idraw_x_train = train_data[:, 1:-1]raw_x_valid = valid_data[:, 1:-1]raw_x_test = test_data[:, 1:]if select_all:feat_idx = list(range(raw_x_train.shape[1]))# 后续修改这里选择合适的特征else:# 使用SelectKBest根据特征与目标之间的相关性来选择k个最重要的特征selector = SelectKBest(f_regression, k=10) #如果是回归问题可以使用f_regression,如果是分类问题可以使用f_classifselector.fit(raw_x_train, y_train)feat_idx = selector.get_support(indices=True) # 获取选中的特征的索引return raw_x_train[:, feat_idx], raw_x_valid[:, feat_idx], raw_x_test[:, feat_idx], y_train, y_valid

蓝色为原始选择全部特征,红色为上述代码选择10个特征的结果,可以发现loss大大降低。
在这里插入图片描述
两次提交的分数如下,有很大的提升
在这里插入图片描述参考作业划分的标准,已达到了strong baseline。
在这里插入图片描述

完整代码

完整代码如下:

import math
import numpy as np
import pandas as pd
import os
import csv
# 进度条
from tqdm import tqdm
# Pytorch
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset, random_split
# tensorboard
from torch.utils.tensorboard import SummaryWriter
# SelectKBest 用于特征选择
from sklearn.feature_selection import SelectKBest, f_regression# 设置随机种子,保证实验的可重复性
def same_seed(seed):# 设置 PyTorch 后端的 cuDNN 为确定性模式,保证每次运行结果一致torch.backends.cudnn.deterministic = True# 禁用 cuDNN 的自动优化,保证每次运行结果一致torch.backends.cudnn.benchmark = Falsenp.random.seed(seed)torch.manual_seed(seed)if torch.cuda.is_available():torch.cuda.manual_seed(seed)torch.cuda.manual_seed_all(seed)# 划分数据集
# 原数据中只有训练集和测试集,从训练集中划分出验证集
def train_valid_split(data_set, valid_ratio, seed):valid_data_size = int(len(data_set) * valid_ratio)train_data_size = len(data_set) - valid_data_sizetrain_data, valid_data = random_split(data_set, [train_data_size, valid_data_size], generator=torch.Generator().manual_seed(seed))return np.array(train_data), np.array(valid_data)# 选择特征,默认是选择全部的(117个)feature来做训练
# 后续可选择合适的特征来优化模型
def select_feat(train_data, valid_data, test_data, select_all = True):# labely_train = train_data[:, -1]y_valid = valid_data[:, -1]# feature# 第一列是idraw_x_train = train_data[:, 1:-1]raw_x_valid = valid_data[:, 1:-1]raw_x_test = test_data[:, 1:]if select_all:feat_idx = list(range(raw_x_train.shape[1]))# 后续修改这里选择合适的特征else:# 使用SelectKBest根据特征与目标之间的相关性来选择k个最重要的特征selector = SelectKBest(f_regression, k=10) #如果是回归问题可以使用f_regression,如果是分类问题可以使用f_classifselector.fit(raw_x_train, y_train)feat_idx = selector.get_support(indices=True) # 获取选中的特征的索引return raw_x_train[:, feat_idx], raw_x_valid[:, feat_idx], raw_x_test[:, feat_idx], y_train, y_valid# 数据集类
class COVID19Dataset(Dataset):def __init__(self, features, targets=None):# 做预测,不用label,只用featuresif targets is None:self.targets = targets  # none# 做训练,有labelelse:self.targets = torch.FloatTensor(targets)self.features = torch.FloatTensor(features)def __getitem__(self, idx):if self.targets is None:return self.features[idx]else:return self.features[idx], self.targets[idx]def __len__(self):return len(self.features)# 神经网络模型
class My_Model(nn.Module):def __init__(self, input_dim):super(My_Model, self).__init__()self.layers = nn.Sequential(nn.Linear(input_dim, 16),nn.ReLU(),nn.Linear(16, 8),nn.ReLU(),nn.Linear(8, 1))def forward(self, x):x = self.layers(x)x = x.squeeze(1)  # (B, 1) -> (B)return x# 参数设置
device = 'cuda' if torch.cuda.is_available() else 'cpu'
config = {'seed': 5201314,'select_all': True,'valid_ratio': 0.2,'n_epochs': 3000,'batch_size': 256,'learning_rate': 1e-5,'early_stop': 400,  # 如果连续400个epoch验证集的loss都没有下降,就提前停止训练'save_path': './models/model.ckpt'
}# 训练过程
def trainer(train_loader, valid_loader, model, config, device):criterion = nn.MSELoss(reduce='mean')  # 默认为mean,计算所有元素的均值作为最终的损失值。# momentum 可以帮助优化器在陡峭的曲面上更快地找到最优解。# 例如,momentum=0.9 表示每次更新时,90%的更新量来自于上一次的更新方向,10%来自于当前的梯度方向。这样可以使得优化过程更加平滑和快速。optimizer = torch.optim.SGD(model.parameters(), lr = config['learning_rate'], momentum=0.9)writer = SummaryWriter()if not os.path.isdir('./models'):os.makedirs('./models')n_epochs = config['n_epochs']best_loss = math.inf  #初始值设置为无穷大step = 0early_stop_count = 0for epoch in range(n_epochs):model.train()loss_record = []# train_loader 被封装以可视化训练进度。position=0表示进度条在最上面,leave=True表示训练完成后不清除进度条train_pbar = tqdm(train_loader, position=0, leave=True)"""训练循环"""for x, y in train_pbar:optimizer.zero_grad()  # 梯度清零x, y = x.to(device), y.to(device)pred = model(x)  # 前向传播loss = criterion(pred, y)  # 计算损失loss.backward()  # 反向传播optimizer.step()  # 更新参数step += 1loss_record.append(loss.item())# 显示训练过程train_pbar.set_description(f'Epoch {epoch + 1}/{n_epochs}')train_pbar.set_postfix({'loss': loss.item()})mean_train_loss = sum(loss_record) / len(loss_record)writer.add_scalar('Loss/train', mean_train_loss, step)"""验证循环"""model.eval()loss_record = []for x, y in valid_loader:x, y = x.to(device), y.to(device)# 验证集不需要计算梯度with torch.no_grad():pred = model(x)loss = criterion(pred, y)loss_record.append(loss.item())mean_valid_loss = sum(loss_record) / len(loss_record)print(f'Epoch {epoch + 1}/{n_epochs}, Train loss: {mean_train_loss: .4f}, Valid loss: {mean_valid_loss: .4f}')writer.add_scalar('Loss/valid', mean_valid_loss, step)# 根据验证集的损失值保存最佳模型。if mean_valid_loss < best_loss:best_loss = mean_valid_losstorch.save(model.state_dict(), config['save_path'])  # Save your best modelprint('Saving model with loss {:.3f}...'.format(best_loss))early_stop_count = 0else:early_stop_count += 1if early_stop_count >= config['early_stop']:print("\n Model is not improving, so we halt the training process.")return"""准备工作"""
# 设置随机种子
same_seed(config['seed'])
# 读取数据
train_data = pd.read_csv('./covid.train.csv').values
test_data = pd.read_csv('./covid.test.csv').values
# 划分数据集
train_data, valid_data = train_valid_split(train_data, config['valid_ratio'], config['seed'])
print(f"""train data size: {len(train_data)}, valid data size: {len(valid_data)}, test data size: {len(test_data)}""")
# 选择特征
x_train, x_valid, x_test, y_train, y_valid = select_feat(train_data, valid_data, test_data, config['select_all'])
print(f"""The number of features: {x_train.shape[1]}""")
# 构造数据集
train_dataset = COVID19Dataset(x_train, y_train)
valid_dataset = COVID19Dataset(x_valid, y_valid)
test_dataset = COVID19Dataset(x_test)
# dataloader
train_loader = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True, pin_memory=True)
valid_loader = DataLoader(valid_dataset, batch_size=config['batch_size'], shuffle=True, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=config['batch_size'], shuffle=False, pin_memory=True)# 开始训练
model = My_Model(input_dim=x_train.shape[1]).to(device)
trainer(train_loader, valid_loader, model, config, device)# 预测
def predict(test_loader, model, device):model.eval()preds = []for x in tqdm(test_loader):x = x.to(device)with torch.no_grad():pred = model(x)preds.append(pred.detach().cpu())preds = torch.cat(preds, dim=0).numpy()return predsdef save_pred(preds, file):with open(file, 'w') as fp:writer = csv.writer(fp)writer.writerow(['id', 'tested_positive'])for i, p in enumerate(preds):writer.writerow([i, p])# 预测并保存结果
model = My_Model(input_dim=x_test.shape[1]).to(device)
model.load_state_dict(torch.load(config['save_path']))
preds = predict(test_loader, model, device)
save_pred(preds, './pred.csv')

相关文章:

李宏毅机器学习HW1: COVID-19 Cases Prediction

Kaggle数据集和提交链接 特征选择&#xff08;主要修改地方&#xff09; 在sample code的基础上主要修改了Select_feat选择特征函数。 首先&#xff0c;因为数据集中的第一列是id&#xff0c;先在raw_x_train&#xff0c;raw_x_valid&#xff0c;raw_x_test中都去掉这一列。其…...

MySQL下载安装DataGrip可视化工具

目录 WinMySQL下载安装步骤MySQL配置添加环境变量 Mac下载安装配置环境变量 DataGrip可视化工具以Win为例了。Mac忘记截图了。步骤都一样 Win MySQL下载 官网&#xff1a; https://www.mysql.com/ 直接进下载界面&#xff1a; https://downloads.mysql.com/archives/installe…...

多平台下Informatica在医疗数据抽取中的应用

一、引言 1.医疗数据抽取与 Informatica 概述 1.1 医疗数据的特点与来源 1.1.1 数据特点 医疗数据具有显著的多样性特点。从数据类型来看&#xff0c;涵盖了结构化数据&#xff0c;如患者的基本信息、检验检查结果等&#xff0c;这些数据通常以表格形式存储&#xff0c;便于…...

用公网服务器实现内网穿透

首先需要一个公网服务器 下载frp 搜索github下载到frp&#xff0c;服务端frps/客户端frpc。。下载的时候要注意自己本地内网机的cpu版本和服务端cpu架构 我的电脑是mac M1PRO版本 下载的是&#xff1a;darwinarm64 比如 服务端一般是Linux&#xff08;Intel 64位CPU&#xf…...

为什么mysql更改表结构时,varchar超过255会锁表

在 MySQL 中&#xff0c;当修改表结构并将 VARCHAR 字段的长度设置为超过 255 时&#xff0c;可能会出现锁表的情况。这与 MySQL 的存储引擎&#xff08;主要是 InnoDB&#xff09;以及表的底层存储方式相关。 原因分析 行格式变化 InnoDB 存储引擎支持多种行格式&#xff08;…...

ASP.NET Core中 JWT 实现无感刷新Token

在 Web 应用开发中&#xff0c;用户登录状态的管理至关重要。为了避免用户频繁遇到登录过期的问题&#xff0c;我们可以通过实现 JWT&#xff08;JSON Web Token&#xff09;刷新机制来提升用户体验 推荐: 使用 Refresh Token&#xff08;双 Token 机制&#xff09; 1. 生成和…...

函数(函数的概念、库函数、自定义函数、形参和实参、return语句、数组做函数参数、嵌套调用和链式访问、函数的声明和定义、static和extern)

一、函数的概念 •C语⾔中的函数&#xff1a;⼀个完成某项特定的任务的⼀⼩段代码 •函数又被翻译为子函数&#xff08;更准确&#xff09; •在C语⾔中我们⼀般会⻅到两类函数&#xff1a;库函数 ⾃定义函数 二、库函数 1 .标准库和头文件 •C语⾔的国际标准ANSIC规定了⼀…...

物联网在烟草行业的应用

物联网技术在烟草行业的应用 物联网技术在烟草行业的应用主要体现在以下几个方面&#xff1a; 智能制造 &#xff1a;物联网技术可以实现对生产过程中的关键参数进行实时监测&#xff0c;确保产品的质量稳定可靠。同时&#xff0c;通过对设备的远程维护和故障诊断&#xff0c;…...

第6章:Python TDD实例变量私有化探索

写在前面 这本书是我们老板推荐过的&#xff0c;我在《价值心法》的推荐书单里也看到了它。用了一段时间 Cursor 软件后&#xff0c;我突然思考&#xff0c;对于测试开发工程师来说&#xff0c;什么才更有价值呢&#xff1f;如何让 AI 工具更好地辅助自己写代码&#xff0c;或许…...

Java操作Excel导入导出——POI、Hutool、EasyExcel

目录 一、POI导入导出 1.数据库导出为Excel文件 2.将Excel文件导入到数据库中 二、Hutool导入导出 1.数据库导出为Excel文件——属性名是列名 2.数据库导出为Excel文件——列名起别名 3.从Excel文件导入数据到数据库——属性名是列名 4.从Excel文件导入数据到数据库…...

BUUCTF_Web([GYCTF2020]Ezsqli)

1.输入1 &#xff0c;正常回显。 2.输入1 &#xff0c;报错false&#xff0c;为字符型注入&#xff0c;单引号闭合。 原因&#xff1a; https://mp.csdn.net/mp_blog/creation/editor/145170456 3.尝试查询字段&#xff0c;回显位置&#xff0c;数据库&#xff0c;都是这个。…...

微软宣布Win11 24H2进入新阶段!设备将自动下载更新

快科技1月19日消息&#xff0c;微软于1月16日更新了支持文档&#xff0c;宣布Windows 11 24H2进入新阶段。 24H2更新于2024年10月1日发布&#xff0c;此前为可选升级&#xff0c;如今微软开始在兼容的Windows 11设备上自动下载并安装24H2版本。 微软表示&#xff1a;“运行Wi…...

SpringBoot:解决前后端请求跨域问题(详细教程)

文章目录 一、前言二、解决方式 2.1 使用 CrossOrigin 注解&#xff08;简单方便&#xff0c;适用于单个或少量接口&#xff09;2.2 全局配置跨域&#xff08;适用于整个项目中大量接口都需要跨域的情况&#xff09;2.3 使用过滤器来处理跨域&#xff08;更底层的实现方式&…...

Android-V lmkd 中的那些属性值

源码基于&#xff1a;Android V 相关博文&#xff1a; Android lmkd 机制详解&#xff08;一&#xff09; Android lmkd 机制详解&#xff08;二&#xff09; Android lmkd 机制从R到T 1. 汇总 属性名说明默认值 ro.lmk.debug 启动 lmkd 的debug 模式&#xff0c;会打印一…...

PageHelper快速使用

依赖 <!--分页插件PageHelper--> <dependency><groupId>com.github.pagehelper</groupId><artifactId>pagehelper-spring-boot-starter</artifactId><version>1.4.7</version> </dependency>示例 /** * 封装分页结果…...

图像处理基础(3):均值滤波器及其变种

均值滤波器可以归为低通滤波器&#xff0c;是一种线性滤波器&#xff0c;其输出为邻域模板内的像素的简单平均值&#xff0c;主要用于图像的模糊和降噪。 均值滤波器的概念非常的直观&#xff0c;使用滤波器窗口内的像素的平均灰度值代替图像中的像素值&#xff0c;这样的结果就…...

力扣刷题心得_JAVA

数学 > 数组 > 链表 > 字符串 > 哈希表 > 双指针 > 递归 > 栈 > 队列 > 树 //一般力扣中传入的参数和新建的对象作为返回值,都不列入空间复杂度中 //但是面试的时候要和面试官商量好,灵活定义空间复杂度 //当然最好是就在传入的对象作为返回值,(在原…...

音乐播放器实现:前端HTML,CSS,JavaScript综合大项目

音乐播放器实现:前端HTML&#xff0c;CSS&#xff0c;JavaScript综合大项目 项目概述项目视图效果一、侧边栏相关代码&#xff08;一&#xff09;HTML代码&#xff08;二&#xff09;css代码 二、登录页面&#xff08;一&#xff09;HTML代码&#xff08;二&#xff09;css代码…...

Unity编辑器缩放设置

Unity默认界面UI字体太小了&#xff0c;可以设置一下缩放 打开首选项&#xff0c; UI Scaling 设置成125%或者更大 &#xff0c;然后重启...

ChatGPT大模型极简应用开发-CH1-初识 GPT-4 和 ChatGPT

文章目录 1.1 LLM 概述1.1.1 语言模型和NLP基础1.1.2 Transformer及在LLM中的作用1.1.3 解密 GPT 模型的标记化和预测步骤 1.2 GPT 模型简史&#xff1a;从 GPT-1 到 GPT-41.2.1 GPT11.2.2 GPT21.2.3 GPT-31.2.4 从 GPT-3 到 InstructGPT1.2.5 GPT-3.5、Codex 和 ChatGPT1.2.6 …...

【Linux】shell脚本忽略错误继续执行

在 shell 脚本中&#xff0c;可以使用 set -e 命令来设置脚本在遇到错误时退出执行。如果你希望脚本忽略错误并继续执行&#xff0c;可以在脚本开头添加 set e 命令来取消该设置。 举例1 #!/bin/bash# 取消 set -e 的设置 set e# 执行命令&#xff0c;并忽略错误 rm somefile…...

label-studio的使用教程(导入本地路径)

文章目录 1. 准备环境2. 脚本启动2.1 Windows2.2 Linux 3. 安装label-studio机器学习后端3.1 pip安装(推荐)3.2 GitHub仓库安装 4. 后端配置4.1 yolo环境4.2 引入后端模型4.3 修改脚本4.4 启动后端 5. 标注工程5.1 创建工程5.2 配置图片路径5.3 配置工程类型标签5.4 配置模型5.…...

今日科技热点速览

&#x1f525; 今日科技热点速览 &#x1f3ae; 任天堂Switch 2 正式发售 任天堂新一代游戏主机 Switch 2 今日正式上线发售&#xff0c;主打更强图形性能与沉浸式体验&#xff0c;支持多模态交互&#xff0c;受到全球玩家热捧 。 &#x1f916; 人工智能持续突破 DeepSeek-R1&…...

Android 之 kotlin 语言学习笔记三(Kotlin-Java 互操作)

参考官方文档&#xff1a;https://developer.android.google.cn/kotlin/interop?hlzh-cn 一、Java&#xff08;供 Kotlin 使用&#xff09; 1、不得使用硬关键字 不要使用 Kotlin 的任何硬关键字作为方法的名称 或字段。允许使用 Kotlin 的软关键字、修饰符关键字和特殊标识…...

有限自动机到正规文法转换器v1.0

1 项目简介 这是一个功能强大的有限自动机&#xff08;Finite Automaton, FA&#xff09;到正规文法&#xff08;Regular Grammar&#xff09;转换器&#xff0c;它配备了一个直观且完整的图形用户界面&#xff0c;使用户能够轻松地进行操作和观察。该程序基于编译原理中的经典…...

docker 部署发现spring.profiles.active 问题

报错&#xff1a; org.springframework.boot.context.config.InvalidConfigDataPropertyException: Property spring.profiles.active imported from location class path resource [application-test.yml] is invalid in a profile specific resource [origin: class path re…...

Java线上CPU飙高问题排查全指南

一、引言 在Java应用的线上运行环境中&#xff0c;CPU飙高是一个常见且棘手的性能问题。当系统出现CPU飙高时&#xff0c;通常会导致应用响应缓慢&#xff0c;甚至服务不可用&#xff0c;严重影响用户体验和业务运行。因此&#xff0c;掌握一套科学有效的CPU飙高问题排查方法&…...

【JVM面试篇】高频八股汇总——类加载和类加载器

目录 1. 讲一下类加载过程&#xff1f; 2. Java创建对象的过程&#xff1f; 3. 对象的生命周期&#xff1f; 4. 类加载器有哪些&#xff1f; 5. 双亲委派模型的作用&#xff08;好处&#xff09;&#xff1f; 6. 讲一下类的加载和双亲委派原则&#xff1f; 7. 双亲委派模…...

「全栈技术解析」推客小程序系统开发:从架构设计到裂变增长的完整解决方案

在移动互联网营销竞争白热化的当下&#xff0c;推客小程序系统凭借其裂变传播、精准营销等特性&#xff0c;成为企业抢占市场的利器。本文将深度解析推客小程序系统开发的核心技术与实现路径&#xff0c;助力开发者打造具有市场竞争力的营销工具。​ 一、系统核心功能架构&…...

深入浅出WebGL:在浏览器中解锁3D世界的魔法钥匙

WebGL&#xff1a;在浏览器中解锁3D世界的魔法钥匙 引言&#xff1a;网页的边界正在消失 在数字化浪潮的推动下&#xff0c;网页早已不再是静态信息的展示窗口。如今&#xff0c;我们可以在浏览器中体验逼真的3D游戏、交互式数据可视化、虚拟实验室&#xff0c;甚至沉浸式的V…...