当前位置: 首页 > news >正文

用来生成二维矩阵的dcgan

有大量二维矩阵作为样本,为连续数据。数据具有空间连续性,因此用卷积网络,通过dcgan生成二维矩阵。因为是连续变量,因此损失采用nn.MSELoss()。

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from DemDataset import create_netCDF_Dem_trainLoader
import torchvision
from torch.utils.tensorboard import SummaryWriterbatch_size=16
#load data
dataloader = create_netCDF_Dem_trainLoader(batch_size)# Generator with Conv2D structure
class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()self.model = nn.Sequential(nn.ConvTranspose2d(100, 512, kernel_size=4, stride=2, padding=1),nn.BatchNorm2d(512),nn.ReLU(),nn.ConvTranspose2d(512, 512, kernel_size=4, stride=2, padding=1),nn.BatchNorm2d(512),nn.ReLU(),nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),nn.BatchNorm2d(256),nn.ReLU(),nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),nn.BatchNorm2d(128),nn.ReLU(),nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),nn.BatchNorm2d(64),nn.ReLU(),nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),nn.BatchNorm2d(32),nn.ReLU(),nn.ConvTranspose2d(32, 1, kernel_size=4, stride=2, padding=1),nn.Tanh())def forward(self, z):img = self.model(z)return img# Discriminator with Conv2D structure
class Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()self.model = nn.Sequential(nn.Conv2d(1, 32, kernel_size=4, stride=2, padding=1),nn.LeakyReLU(0.2),nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),nn.LeakyReLU(0.2),nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),nn.LeakyReLU(0.2),nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),nn.LeakyReLU(0.2),nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1),nn.LeakyReLU(0.2),nn.Conv2d(512, 512, kernel_size=4, stride=2, padding=1),nn.LeakyReLU(0.2),nn.Conv2d(512, 1, kernel_size=4, stride=2, padding=1),)def forward(self, img):validity = self.model(img)return validity# Initialize GAN components
generator = Generator()
discriminator = Discriminator()# Define loss function and optimizers
criterion = nn.MSELoss()
optimizer_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
generator.to(device)
discriminator.to(device)writer_real = SummaryWriter(f"logs/real")
writer_fake = SummaryWriter(f"logs/fake")
step = 0# Training loop
num_epochs = 200
for epoch in range(num_epochs):for batch_idx, real_data in enumerate(dataloader):real_data = real_data.to(device)# Train Discriminatoroptimizer_D.zero_grad()real_labels = torch.ones(real_data.size(0), 1).to(device)fake_labels = torch.zeros(real_data.size(0), 1).to(device)z = torch.randn(real_data.size(0), 100, 1, 1).to(device)fake_data = generator(z)real_pred = discriminator(real_data)fake_pred = discriminator(fake_data.detach())d_loss_real = criterion(real_pred, real_labels)d_loss_fake = criterion(fake_pred, fake_labels)d_loss = d_loss_real + d_loss_faked_loss.backward()optimizer_D.step()# Train Generatoroptimizer_G.zero_grad()z = torch.randn(real_data.size(0), 100, 1, 1).to(device)fake_data = generator(z)fake_pred = discriminator(fake_data)g_loss = criterion(fake_pred, real_labels)g_loss.backward()optimizer_G.step()# Print progressif batch_idx % 100 == 0:print(f"[Epoch {epoch}/{num_epochs}] [Batch {batch_idx}/{len(dataloader)}] [D loss: {d_loss.item():.4f}] [G loss: {g_loss.item():.4f}]")with torch.no_grad():img_grid_real = torchvision.utils.make_grid(fake_data#, normalize=True,)img_grid_fake = torchvision.utils.make_grid(real_data#, normalize=True)writer_fake.add_image("fake_img", img_grid_fake, global_step=step)writer_real.add_image("real_img", img_grid_real, global_step=step)step += 1# After training, you can generate a 2D array by sampling from the generator
z = torch.randn(1, 100, 1, 1).to(device)
generated_array = generator(z)

相关文章:

用来生成二维矩阵的dcgan

有大量二维矩阵作为样本,为连续数据。数据具有空间连续性,因此用卷积网络,通过dcgan生成二维矩阵。因为是连续变量,因此损失采用nn.MSELoss()。 import torch import torch.nn as nn import torch.optim as optim import numpy a…...

免费的国产数据集成平台推荐

在如今的数字化时代下,企业内部的数据无疑是重要资产之一。随着数据源的多样性和数量剧增,如何有效地收集、整合、存储、管理和分析数据变得至关重要。为了解决这些常见痛点,数据集成平台成为了现代企业不可或缺的一部分。 数据集成是现代数…...

【yolov8系列】yolov8的目标检测、实例分割、关节点估计的原理解析

1 YOLO时间线 这里简单列下yolo的发展时间线,对每个版本的提出有个时间概念。 2 yolov8 的简介 工程链接:https://github.com/ultralytics/ultralytics 2.1 yolov8的特点 采用了anchor free方式,去除了先验设置可能不佳带来的影响借鉴Genera…...

5256C 5G终端综合测试仪

01 5256C 5G终端综合测试仪 产品综述: 5256C 5G终端综合测试仪主要用于5G终端、基带芯片的研发、生产、校准、检测、认证和教学等领域。该仪表具备5G信号发送功能、5G信号功率特性、解调特性和频谱特性分析功能,支持5G终端的产线高速校准及终端发射机…...

Springboot Actuator 环境搭建踩坑

JMX和Springboot Actuator JMX是Java Management Extensions,它是一个Java平台的管理和监控接口。 为什么要搞JMX呢?因为在所有的应用程序中,对运行中的程序进行监控都是非常重要的,Java应用程序也不例外。我们肯定希望知道Java…...

Vue-3.3ESLint

ESLint代码规范 代码规范:一套写代码的约定规则。 JavaScript Standard Style规范说明https://standardjs.com/rules-zhcn.html 代码规范错误 如果你的代码不符合standard的要求,ESlint会跳出来提醒。 比如:在mian.js中随意做一些改动&a…...

STROBE-MR

Welcome to the STROBE-MR website! About: STROBE-MR stands for “Strengthening the Reporting of Observational Studies in Epidemiology using Mendelian Randomization”. Inspired by the original STROBE checklist, the STROBE-MR guidelines were developed to ass…...

Hive安装配置 - 内嵌模式

文章目录 一、Hive运行模式二、安装配置内嵌模式Hive(一)下载hive安装包(二)上传hive安装包(三)解压缩hive安装包(四)配置hive环境变量(五)关联Hadoop&#x…...

html中登录按钮添加回车键登录

原文链接有3种方法&#xff0c;其它2中不会弄&#xff0c;第二种方法成功&#xff0c;下面详细说说 原html的登录部分是 <button class"btn btn-success btn-block waves-effect waves-light" id"button" >登入</button> 在该html中增加 &…...

PCL 空间两平面交线计算

PCL 空间两平面交线计算 std::vector<float> LineInPlanes(std::vector<double> para1, std::vector<double> para2) {std::vector<float...

交替合并字符串

题目要求 给你两个字符串 word1 和 word2 。请你从 word1 开始&#xff0c;通过交替添加字母来合并字符串。如果一个字符串比另一个字符串长&#xff0c;就将多出来的字母追加到合并后字符串的末尾。 返回 合并后的字符串 。 示例 示例 1&#xff1a; 输入&#xff1a;word1 …...

Linux考试复习整理

文章目录 Linux考试整理一.选择题1.用户的密码现象放置在哪个文件夹&#xff1f;2.删除文件或目录的命令是&#xff1f;3.显示一个文件最后几行的命令是&#xff1f;4.删除一个用户并同时删除用户的主目录5.Linux配置文件一般放在什么目录&#xff1f;6.某文件的组外成员的权限…...

基于geojson-vt和canvas的高性能出图

概述 本文介绍基于geojson-vt和canvas&#xff0c;实现node端高性能出图。 效果 实现 1. canvas绘图 import { createCanvas } from canvasconst tileSize 256; const canvas createCanvas(tileSize, tileSize) const ctx canvas.getContext(2d)2. 处理geojson const g…...

CTF是黑客大赛?新手如何入门CTF?

CTF是啥 CTF 是 Capture The Flag 的简称&#xff0c;中文咱们叫夺旗赛&#xff0c;其本意是西方的一种传统运动。在比赛上两军会互相争夺旗帜&#xff0c;当有一方的旗帜已被敌军夺取&#xff0c;就代表了那一方的战败。在信息安全领域的 CTF 是说&#xff0c;通过各种攻击手…...

电脑开不了机用U盘重装系统Win10教程

如果我们遇到了电脑开不起机的问题&#xff0c;这给我们的正常使用带来了很大的影响。这时候我们可以借助U盘重装系统的方法&#xff0c;轻松应对这一问题。下面小编给大家详细介绍关于用U盘给开不机的电脑重装Win10系统的教程步骤&#xff0c;操作后用户就能正常使用电脑了。 …...

四叉堆在GO中的应用-定时任务timer

堆作为必须掌握的数据结构之一&#xff0c;在众多场景中也得到了广泛的应用。 比较典型的&#xff0c;如java中的优先队列PriorityQueue、算法中的TOP-K问题、最短路径Dijkstra算法等&#xff0c;在这些经典应用中堆都担任着灵魂般的角色。 理论基础 binary heap 再一起回忆…...

Flow深入浅出系列之使用Kotlin Flow自动刷新Android数据的策略

Flow深入浅出系列之在ViewModels中使用Kotlin FlowsFlow深入浅出系列之更聪明的分享 Kotlin FlowsFlow深入浅出系列之使用Kotlin Flow自动刷新Android数据的策略 Flow深入浅出系列之使用Kotlin Flow自动刷新Android数据的策略 讨论在Android应用程序中使用Kotlin Flow高效加载…...

AC修炼计划(AtCoder Regular Contest 165)

传送门&#xff1a;AtCoder Regular Contest 165 - AtCoder 本次习题参考了樱雪猫大佬的题解&#xff0c;大佬的题解传送门如下&#xff1a;Atcoder Regular Contest 165 - 樱雪喵 - 博客园 (cnblogs.com) A - Sum equals LCM 第一题不算特别难 B - Sliding Window Sort 2 对…...

【Express】登录鉴权 JWT

JWT&#xff08;JSON Web Token&#xff09;是一种用于实现身份验证和授权的开放标准。它是一种基于JSON的安全传输数据的方式&#xff0c;由三部分组成&#xff1a;头部、载荷和签名。 使用jsonwebtoken模块&#xff0c;你可以在Node.js应用程序中轻松生成和验证JWT。以下是j…...

【微服务 SpringCloud】实用篇 · Ribbon负载均衡

微服务&#xff08;4&#xff09; 文章目录 微服务&#xff08;4&#xff09;1. 负载均衡原理2. 源码跟踪1&#xff09;LoadBalancerIntercepor2&#xff09;LoadBalancerClient3&#xff09;负载均衡策略IRule4&#xff09;总结 3. 负载均衡策略3.1 负载均衡策略3.2 自定义负载…...

zabbix-proxy代理服务器配置

下载zabbix源 rpm -Uvh https://repo.zabbix.com/zabbix/5.0/rhel/7/x86_64/zabbix-release-5.0-1.el7.noarch.rpm 安装 yum -y install zabbix-proxy-mysql zabbix_get 查看相关文件路径 rpm -ql zabbix-proxy-mysql 创建数据库 mysq -uroot -proot mysql> create database…...

【python零基础入门学习】python进阶篇之OOP - 面向对象的程序设计

本站以分享各种运维经验和运维所需要的技能为主 《python零基础入门》&#xff1a;python零基础入门学习 《python运维脚本》&#xff1a; python运维脚本实践 《shell》&#xff1a;shell学习 《terraform》持续更新中&#xff1a;terraform_Aws学习零基础入门到最佳实战 《k8…...

中国xx集团信息技术工程师面试

进入面试间&#xff0c;坐着三位面试官&#xff0c;压力扑面而来&#xff0c;三位面试官先做了自我介绍&#xff0c;介绍了一下面试的流程后才开始面试。 一、自我介绍 不多说。 二、看你学过数据挖掘这门课&#xff0c;能简单介绍一下有哪些章节&#xff0c;学了些什么&…...

Jmeter接口自动化测试 —— Jmeter下载安装及入门

jmeter简介 Apache JMeter是Apache组织开发的基于Java的压力测试工具。用于对软件做压力测试&#xff0c;它最初被设计用于Web应用测试&#xff0c;但后来扩展到其他测试领域。 下载 下载地址&#xff1a;Apache JMeter - Download Apache JMeter 安装 由于Jmeter是基于Java的…...

ARM 学习笔记2 初识Cortex-M33与STM32G4

入门 ARM Cortex-M系列处理器的差异与联系&#xff1a;【ARM Cortex-M 系列 1 – Cortex-M0, M3, M4, M7, M33 差异】两本书籍的英文版和中文版 Definitive Guide to Arm Cortex-M23 and Cortex-M33 Processors Arm Cortex-M23和Cortex-M33微处理器权威指南ST的介绍页 Arm Cor…...

vue中使用coordtransform 互相转换坐标系

官方网站&#xff1a;https://www.npmjs.com/package/coordtransform 在使用高德sdk时&#xff0c;其返回的坐标在地图上显示时有几百米的偏移&#xff0c;这是由于高德用的是 火星坐标&#xff08;GCJ02&#xff09;&#xff0c;而不是wgs84坐标。为了消除偏移&#xff0c;将G…...

双线性插值详解

双线性插值的原理网上资料非常多,本文重点介绍双线性插值实现的两种方式: 角对齐(coner_align = True) 和 边对齐(coner_align = False)。两种不能的方式下去实现双线性插值,目标图像中的每个像素点,它是如何计算取值的,本文会通过原理结合代码的方式将实现细节讲清楚。 1…...

C++ “”

&加上有时候会加速 如果想该对象跟着函数变化一定要加“&” 在题目函数里面定义的 例如 vector<vector<bool>> visited(grid.size(),vector<bool>(grid[0].size(),false)); 如果自己定义的新void dfs&#xff08;vector<vector<bool>>…...

计算机三级有必要考吗?计算机三级有哪些科目?

在大学期间&#xff0c;计算机等级考试是一门很火热的考试&#xff0c;很多小伙伴通过二级考试以后在究竟是报考三级还是四级之间徘徊&#xff0c;下面肉丸子就来给大家分析一下&#xff0c;究竟有没有必要考计算机三级考试&#xff0c;以及计算机三级考试的科目有哪些&#xf…...

6.5 Elasticsearch(五)Spring Data Elasticsearch - 增删改查API

文章目录 1.Spring Data Elasticsearch2.案例准备2.1 在 Elasticsearch 中创建 students 索引2.2 案例测试说明 3.创建项目3.1 新建工程3.2 新建 springboot module&#xff0c;添加 spring data elasticsearch 依赖3.3 pom.xml 文件3.4 application.yml 配置 4.Student 实体类…...