当前位置: 首页 > news >正文

用来生成二维矩阵的dcgan

有大量二维矩阵作为样本,为连续数据。数据具有空间连续性,因此用卷积网络,通过dcgan生成二维矩阵。因为是连续变量,因此损失采用nn.MSELoss()。

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from DemDataset import create_netCDF_Dem_trainLoader
import torchvision
from torch.utils.tensorboard import SummaryWriterbatch_size=16
#load data
dataloader = create_netCDF_Dem_trainLoader(batch_size)# Generator with Conv2D structure
class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()self.model = nn.Sequential(nn.ConvTranspose2d(100, 512, kernel_size=4, stride=2, padding=1),nn.BatchNorm2d(512),nn.ReLU(),nn.ConvTranspose2d(512, 512, kernel_size=4, stride=2, padding=1),nn.BatchNorm2d(512),nn.ReLU(),nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),nn.BatchNorm2d(256),nn.ReLU(),nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),nn.BatchNorm2d(128),nn.ReLU(),nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),nn.BatchNorm2d(64),nn.ReLU(),nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),nn.BatchNorm2d(32),nn.ReLU(),nn.ConvTranspose2d(32, 1, kernel_size=4, stride=2, padding=1),nn.Tanh())def forward(self, z):img = self.model(z)return img# Discriminator with Conv2D structure
class Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()self.model = nn.Sequential(nn.Conv2d(1, 32, kernel_size=4, stride=2, padding=1),nn.LeakyReLU(0.2),nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),nn.LeakyReLU(0.2),nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),nn.LeakyReLU(0.2),nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),nn.LeakyReLU(0.2),nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1),nn.LeakyReLU(0.2),nn.Conv2d(512, 512, kernel_size=4, stride=2, padding=1),nn.LeakyReLU(0.2),nn.Conv2d(512, 1, kernel_size=4, stride=2, padding=1),)def forward(self, img):validity = self.model(img)return validity# Initialize GAN components
generator = Generator()
discriminator = Discriminator()# Define loss function and optimizers
criterion = nn.MSELoss()
optimizer_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
generator.to(device)
discriminator.to(device)writer_real = SummaryWriter(f"logs/real")
writer_fake = SummaryWriter(f"logs/fake")
step = 0# Training loop
num_epochs = 200
for epoch in range(num_epochs):for batch_idx, real_data in enumerate(dataloader):real_data = real_data.to(device)# Train Discriminatoroptimizer_D.zero_grad()real_labels = torch.ones(real_data.size(0), 1).to(device)fake_labels = torch.zeros(real_data.size(0), 1).to(device)z = torch.randn(real_data.size(0), 100, 1, 1).to(device)fake_data = generator(z)real_pred = discriminator(real_data)fake_pred = discriminator(fake_data.detach())d_loss_real = criterion(real_pred, real_labels)d_loss_fake = criterion(fake_pred, fake_labels)d_loss = d_loss_real + d_loss_faked_loss.backward()optimizer_D.step()# Train Generatoroptimizer_G.zero_grad()z = torch.randn(real_data.size(0), 100, 1, 1).to(device)fake_data = generator(z)fake_pred = discriminator(fake_data)g_loss = criterion(fake_pred, real_labels)g_loss.backward()optimizer_G.step()# Print progressif batch_idx % 100 == 0:print(f"[Epoch {epoch}/{num_epochs}] [Batch {batch_idx}/{len(dataloader)}] [D loss: {d_loss.item():.4f}] [G loss: {g_loss.item():.4f}]")with torch.no_grad():img_grid_real = torchvision.utils.make_grid(fake_data#, normalize=True,)img_grid_fake = torchvision.utils.make_grid(real_data#, normalize=True)writer_fake.add_image("fake_img", img_grid_fake, global_step=step)writer_real.add_image("real_img", img_grid_real, global_step=step)step += 1# After training, you can generate a 2D array by sampling from the generator
z = torch.randn(1, 100, 1, 1).to(device)
generated_array = generator(z)

相关文章:

用来生成二维矩阵的dcgan

有大量二维矩阵作为样本,为连续数据。数据具有空间连续性,因此用卷积网络,通过dcgan生成二维矩阵。因为是连续变量,因此损失采用nn.MSELoss()。 import torch import torch.nn as nn import torch.optim as optim import numpy a…...

免费的国产数据集成平台推荐

在如今的数字化时代下,企业内部的数据无疑是重要资产之一。随着数据源的多样性和数量剧增,如何有效地收集、整合、存储、管理和分析数据变得至关重要。为了解决这些常见痛点,数据集成平台成为了现代企业不可或缺的一部分。 数据集成是现代数…...

【yolov8系列】yolov8的目标检测、实例分割、关节点估计的原理解析

1 YOLO时间线 这里简单列下yolo的发展时间线,对每个版本的提出有个时间概念。 2 yolov8 的简介 工程链接:https://github.com/ultralytics/ultralytics 2.1 yolov8的特点 采用了anchor free方式,去除了先验设置可能不佳带来的影响借鉴Genera…...

5256C 5G终端综合测试仪

01 5256C 5G终端综合测试仪 产品综述: 5256C 5G终端综合测试仪主要用于5G终端、基带芯片的研发、生产、校准、检测、认证和教学等领域。该仪表具备5G信号发送功能、5G信号功率特性、解调特性和频谱特性分析功能,支持5G终端的产线高速校准及终端发射机…...

Springboot Actuator 环境搭建踩坑

JMX和Springboot Actuator JMX是Java Management Extensions,它是一个Java平台的管理和监控接口。 为什么要搞JMX呢?因为在所有的应用程序中,对运行中的程序进行监控都是非常重要的,Java应用程序也不例外。我们肯定希望知道Java…...

Vue-3.3ESLint

ESLint代码规范 代码规范:一套写代码的约定规则。 JavaScript Standard Style规范说明https://standardjs.com/rules-zhcn.html 代码规范错误 如果你的代码不符合standard的要求,ESlint会跳出来提醒。 比如:在mian.js中随意做一些改动&a…...

STROBE-MR

Welcome to the STROBE-MR website! About: STROBE-MR stands for “Strengthening the Reporting of Observational Studies in Epidemiology using Mendelian Randomization”. Inspired by the original STROBE checklist, the STROBE-MR guidelines were developed to ass…...

Hive安装配置 - 内嵌模式

文章目录 一、Hive运行模式二、安装配置内嵌模式Hive(一)下载hive安装包(二)上传hive安装包(三)解压缩hive安装包(四)配置hive环境变量(五)关联Hadoop&#x…...

html中登录按钮添加回车键登录

原文链接有3种方法&#xff0c;其它2中不会弄&#xff0c;第二种方法成功&#xff0c;下面详细说说 原html的登录部分是 <button class"btn btn-success btn-block waves-effect waves-light" id"button" >登入</button> 在该html中增加 &…...

PCL 空间两平面交线计算

PCL 空间两平面交线计算 std::vector<float> LineInPlanes(std::vector<double> para1, std::vector<double> para2) {std::vector<float...

交替合并字符串

题目要求 给你两个字符串 word1 和 word2 。请你从 word1 开始&#xff0c;通过交替添加字母来合并字符串。如果一个字符串比另一个字符串长&#xff0c;就将多出来的字母追加到合并后字符串的末尾。 返回 合并后的字符串 。 示例 示例 1&#xff1a; 输入&#xff1a;word1 …...

Linux考试复习整理

文章目录 Linux考试整理一.选择题1.用户的密码现象放置在哪个文件夹&#xff1f;2.删除文件或目录的命令是&#xff1f;3.显示一个文件最后几行的命令是&#xff1f;4.删除一个用户并同时删除用户的主目录5.Linux配置文件一般放在什么目录&#xff1f;6.某文件的组外成员的权限…...

基于geojson-vt和canvas的高性能出图

概述 本文介绍基于geojson-vt和canvas&#xff0c;实现node端高性能出图。 效果 实现 1. canvas绘图 import { createCanvas } from canvasconst tileSize 256; const canvas createCanvas(tileSize, tileSize) const ctx canvas.getContext(2d)2. 处理geojson const g…...

CTF是黑客大赛?新手如何入门CTF?

CTF是啥 CTF 是 Capture The Flag 的简称&#xff0c;中文咱们叫夺旗赛&#xff0c;其本意是西方的一种传统运动。在比赛上两军会互相争夺旗帜&#xff0c;当有一方的旗帜已被敌军夺取&#xff0c;就代表了那一方的战败。在信息安全领域的 CTF 是说&#xff0c;通过各种攻击手…...

电脑开不了机用U盘重装系统Win10教程

如果我们遇到了电脑开不起机的问题&#xff0c;这给我们的正常使用带来了很大的影响。这时候我们可以借助U盘重装系统的方法&#xff0c;轻松应对这一问题。下面小编给大家详细介绍关于用U盘给开不机的电脑重装Win10系统的教程步骤&#xff0c;操作后用户就能正常使用电脑了。 …...

四叉堆在GO中的应用-定时任务timer

堆作为必须掌握的数据结构之一&#xff0c;在众多场景中也得到了广泛的应用。 比较典型的&#xff0c;如java中的优先队列PriorityQueue、算法中的TOP-K问题、最短路径Dijkstra算法等&#xff0c;在这些经典应用中堆都担任着灵魂般的角色。 理论基础 binary heap 再一起回忆…...

Flow深入浅出系列之使用Kotlin Flow自动刷新Android数据的策略

Flow深入浅出系列之在ViewModels中使用Kotlin FlowsFlow深入浅出系列之更聪明的分享 Kotlin FlowsFlow深入浅出系列之使用Kotlin Flow自动刷新Android数据的策略 Flow深入浅出系列之使用Kotlin Flow自动刷新Android数据的策略 讨论在Android应用程序中使用Kotlin Flow高效加载…...

AC修炼计划(AtCoder Regular Contest 165)

传送门&#xff1a;AtCoder Regular Contest 165 - AtCoder 本次习题参考了樱雪猫大佬的题解&#xff0c;大佬的题解传送门如下&#xff1a;Atcoder Regular Contest 165 - 樱雪喵 - 博客园 (cnblogs.com) A - Sum equals LCM 第一题不算特别难 B - Sliding Window Sort 2 对…...

【Express】登录鉴权 JWT

JWT&#xff08;JSON Web Token&#xff09;是一种用于实现身份验证和授权的开放标准。它是一种基于JSON的安全传输数据的方式&#xff0c;由三部分组成&#xff1a;头部、载荷和签名。 使用jsonwebtoken模块&#xff0c;你可以在Node.js应用程序中轻松生成和验证JWT。以下是j…...

【微服务 SpringCloud】实用篇 · Ribbon负载均衡

微服务&#xff08;4&#xff09; 文章目录 微服务&#xff08;4&#xff09;1. 负载均衡原理2. 源码跟踪1&#xff09;LoadBalancerIntercepor2&#xff09;LoadBalancerClient3&#xff09;负载均衡策略IRule4&#xff09;总结 3. 负载均衡策略3.1 负载均衡策略3.2 自定义负载…...

Go 网关模式:让业务逻辑和外部服务“保持距离“的艺术

&#x1f3ac; 场景小剧场 想象一下&#xff1a;你的电商系统要接支付功能。如果直接在 order 包里写 stripe.Charge()&#xff0c;明天老板说"换支付宝"&#xff0c;你就要满世界改代码 &#x1f62b; 网关模式就是给业务逻辑装个"万能插座"&#xff1a;不…...

DVWA-Chinese安全实践指南:从环境搭建到漏洞攻防

DVWA-Chinese安全实践指南&#xff1a;从环境搭建到漏洞攻防 【免费下载链接】DVWA-Chinese DVWA全汉化版本 项目地址: https://gitcode.com/gh_mirrors/dv/DVWA-Chinese 价值定位&#xff1a;为什么选择DVWA-Chinese作为安全学习平台 合法可控的漏洞实验场 Web安全学…...

League-Toolkit:英雄联盟客户端集成工具包的全方位应用指南

League-Toolkit&#xff1a;英雄联盟客户端集成工具包的全方位应用指南 【免费下载链接】League-Toolkit An all-in-one toolkit for LeagueClient. Gathering power &#x1f680;. 项目地址: https://gitcode.com/gh_mirrors/le/League-Toolkit 一、游戏场景中的实际挑…...

音频的爬虫

1.前提准备需要在终端中下载requests模块 --- 终端在软件的左下角&#xff0c;下方图案例下载的语法&#xff1a;pip install requests&#xff08;1&#xff09;下载成功会报出的结果&#xff0c;如下图所示&#xff1a;&#xff08;2&#xff09;下载失败会报出的结果&#…...

专业级GTA5辅助工具:YimMenu全维度安全防护与功能增强指南

专业级GTA5辅助工具&#xff1a;YimMenu全维度安全防护与功能增强指南 【免费下载链接】YimMenu YimMenu, a GTA V menu protecting against a wide ranges of the public crashes and improving the overall experience. 项目地址: https://gitcode.com/GitHub_Trending/yi/…...

城通网盘下载速度慢?试试ctfileGet,让你畅享本地高速解析体验

城通网盘下载速度慢&#xff1f;试试ctfileGet&#xff0c;让你畅享本地高速解析体验 【免费下载链接】ctfileGet 获取城通网盘一次性直连地址 项目地址: https://gitcode.com/gh_mirrors/ct/ctfileGet 在数字化办公与学习中&#xff0c;网盘已成为文件传输的重要工具。…...

Feather生态系统探索:从R包到Python包装器的完整技术栈

Feather生态系统探索&#xff1a;从R包到Python包装器的完整技术栈 【免费下载链接】feather wesm/feather: 是一个用于在 Python 和 R 之间传输数据的轻量级数据格式库。适合对数据科学和数据分析有兴趣的人&#xff0c;特别是需要在 Python 和 R 之间进行数据交换的人。特点是…...

SenseVoice-Small ONNX模型效果惊艳展示:中英粤日韩五语种同步识别样例

SenseVoice-Small ONNX模型效果惊艳展示&#xff1a;中英粤日韩五语种同步识别样例 今天&#xff0c;我想带大家看一个让我眼前一亮的语音识别模型——SenseVoice-Small的ONNX版本。它最吸引我的地方&#xff0c;是能同时识别中文、英文、粤语、日语和韩语&#xff0c;而且速度…...

你的QQ空间记忆会消失吗?GetQzonehistory终极备份方案让你完整珍藏青春印记

你的QQ空间记忆会消失吗&#xff1f;GetQzonehistory终极备份方案让你完整珍藏青春印记 【免费下载链接】GetQzonehistory 获取QQ空间发布的历史说说 项目地址: https://gitcode.com/GitHub_Trending/ge/GetQzonehistory 在数字时代&#xff0c;我们的青春记忆大多散落在…...

Qwen3.5-2B企业降本案例:用2B模型替代8B,GPU成本降低57%实录

Qwen3.5-2B企业降本案例&#xff1a;用2B模型替代8B&#xff0c;GPU成本降低57%实录 1. 轻量化模型带来的成本革命 在AI应用大规模落地的今天&#xff0c;模型部署成本已成为企业最关注的痛点之一。我们团队近期完成了一个典型案例&#xff1a;用Qwen3.5-2B模型成功替代原有8…...