当前位置: 首页 > article >正文

LaSt-ViT:Vision Transformers Need More Than Registers(CVPR 2026)

前言尽管 Vision Transformers (ViTs) 在图像分类等领域取得了巨大成功但其内部机制仍存在诸多未解之谜。近年来的研究发现在需要密集特征的下游任务中ViTs 表现出多种令人困惑的伪影 (Artifacts)这些问题普遍存在于不同的训练范式中全监督 (Fully-supervised)存在明显的“注意力缺陷”生成的特征图无法有效聚焦于物体主体导致局部特征提取能力受限。文本监督 (Text-supervised)稠密特征与文本提示的对齐精度不佳难以在像素级别精准匹配语义信息。自监督 (Self-supervised)模型中出现“高范数令牌” (High-norm tokens)成为干扰项严重影响对目标物体的精确定位。这些问题背后是否有一个共同的原因呢今天我们要学习的是来自CVPR2026的前沿研究LaSt ViT(LazyStrike ViT)。论文Vision Transformers Need More Than Registers代码https://github.com/ChengShiest/LAST-ViT惰性聚合为了解释这些现象这篇工作提出了“惰性聚合假说”。该假说认为ViT的伪影源于一种“偷懒”行为“在粗粒度语义监督和全局注意力机制的共同作用下ViT倾向于利用语义上无关的背景补丁作为捷径 (Shortcut) 来编码全局语义而非专注于前景目标。”粗粒度语义监督(Coarse-grained Supervision)指模型仅获取图片级的类别标签缺乏对各图像块(Patch)的精确监督。模型可以通过任何足以区分不同类别的特征完成任务而无需依赖最具代表性的前景特征。而全局注意力机制允许信息在所有图像块之间自由、高效地流动。若某些背景特征与类别高度相关模型会快速将前景语义“扩散”到背景上形成训练捷径。于是模型发现与其费力地去学习前景特征不如利用无处不在的背景作为捷径来完成分类任务。这就是ViT伪影的根源。假说验证这里定义了两个指标Patch Score 每个补丁特征与CLS令牌的余弦相似度。高分值意味着该补丁与图像整体语义高度相关是衡量局部特征重要性的核心指标。Point-in-Box (PiB) 最高分补丁落在前景标注框内的图像比例。PiB数值越高代表模型越能准确地将“全局语义重心”与“视觉前景”对齐。Patch Score衡量每个图像块与整体语义的相关性PiB衡量最高分块是否落在前景。令人惊讶的是背景块的Patch Score分数反而更高。更有趣的是当我们把分数最高的一半块大多是背景遮住模型的分类准确率竟然没怎么变。这直接证明了ViT确实是在“偷懒”靠背景来完成任务。研究还给出了三个证据懒惰从训练初期就存在在追踪训练全流程的过程中发现了一个极具反差性的现象ViT模型的分类准确率曲线表现完美随着训练推进稳步攀升。然而衡量其定位前景能力的指标——PiB分数却始终在约42%的低位徘徊远低于基于卷积的ResNet模型。这一结果有力地证明模型利用背景信息走捷径的“伪影”并非训练后期才产生的副作用而是从训练初期就已形成并贯穿全过程。模型在学习分类任务的同时从第一个epoch起就选择了更容易的背景特征而非学习真正的前景特征。粗粒度监督引导ViT偷懒分类任务的标签仅告知模型图中“有什么”未指明物体“在哪里”。由于背景 Patch 数量远多于前景模型为了快速收敛会利用背景线索而非学习物体特征。研究团队也做了实验把patch从16x16改到了28x28pib有所上升但准确率却有所下降。这表明了增大 Patch 尺寸以减少背景干扰结果显示定位能力PiB上升但分类准确率下降。这表明模型宁愿牺牲定位能力也要依赖背景“捷径”。全局注意力的影响为验证全局注意力是否会通过允许前景语义传播到背景区域而加剧其惰性聚合行为。为逐步限制长程依赖关系在不同层级将全局自注意力替换为基于窗口的注意力详细实验如下表所示实验说明随着全局注意力的限制PIB得分逐渐提高在所有层均采用窗口注意力时达到最高分数59.8。然而准确性相应下降到了63.9这表明尽管全局上下文对分类有益但它也会促进语义向背景区域扩散。LaSt-ViT为了从根本上解决“惰性聚合”问题LaSt-ViT (LazyStrike ViT)采用了一种简单而有效的基于频率感知的选择性聚合方案。彻底重构 CLS 令牌的聚合逻辑不再像传统 ViT 那样对所有补丁特征进行“无差别”聚合。 转而让 CLS 令牌只选择性地聚合来自前景补丁的有效特征将背景视为干扰并在聚合阶段予以过滤。利用深层网络特征图的通道维度变化差异通过频率分析实现前景与背景的分离前景语义均匀 → 特征变化小 →低频信号背景语义杂乱 → 特征变化大 →高频信号利用这个特性我们就能通过频率分析来筛选出特征稳定的patch。核心实现LazyStrike ViT的核心操作主要有三步1计算稳定性分数 (Stability Score)首先对每个补丁的特征向量在通道维度进行一维傅里叶变换应用高斯低通滤波后逆变换。滤波后与原始特征的比值即为稳定性分数分数越高代表语义越稳定越可能是前景。2通道级Top-K聚合 (Channel-wise Top-K Pooling)为每个通道选择稳定性分数最高的K个补丁对其特征做平均池化将结果整合到CLS令牌中从而聚合各通道最具信息量的局部特征。3投票计数筛选 (Vote Count)统计每个补丁在所有通道中被选中的次数次数越多代表补丁的重要性越高。通过这种投票机制进一步强化了对图像中前景区域的特征表征能力。通过这三步我们就能让CLS令牌精准地锚定前景。算法实现可以参考下面 LAST-ViT: Vision Transformers Need More Than Registers Core implementation of the frequency-domain token selection mechanism. Reference: https://arxiv.org/abs/2602.22394 import torch import torch.nn as nn from torchvision.models.vision_transformer import VisionTransformer class LASTViT(VisionTransformer): LAST-ViT replaces the standard CLS token with a frequency-domain token selection mechanism. Key idea: 1. Apply FFT Gaussian low-pass filter to patch tokens 2. Compute stability scores: diff original / |filtered - original| 3. Select the most stable patch token (top-k) 4. Average selected tokens as the new CLS token def __init__(self, image_size224, patch_size16, num_layers12, num_heads12, hidden_dim768, mlp_dim3072, num_classes1000, top_k1, sigmaNone, **kwargs): super().__init__( image_sizeimage_size, patch_sizepatch_size, num_layersnum_layers, num_headsnum_heads, hidden_dimhidden_dim, mlp_dimmlp_dim, **kwargs ) self.top_k top_k self.sigma sigma if sigma is not None else hidden_dim ** 0.5 self.cached_kernel None # Replace classification head (ViT_B_16_Weights has 1000 classes by default) self.heads nn.Linear(hidden_dim, num_classes) if num_classes ! 1000 else self.heads def gaussian_kernel_1d(self, kernel_size: int, sigma: float) - torch.Tensor: Create a 1D Gaussian kernel for frequency-domain filtering. positions torch.arange(-kernel_size // 2 1, kernel_size // 2 1).float() kernel torch.exp(-0.5 * (positions / sigma) ** 2) kernel kernel / kernel.max() return kernel def low_pass_filter(self, patch_tokens: torch.Tensor) - torch.Tensor: Apply frequency-domain low-pass filter (Gaussian in frequency domain). Args: patch_tokens: [B, N, D] patch embeddings Returns: Filtered patch tokens [B, N, D] original_dtype patch_tokens.dtype # Use float for FFT to avoid precision issues if patch_tokens.dtype in {torch.float16, torch.bfloat16}: patch_tokens patch_tokens.float() # Lazy initialization of Gaussian kernel if self.cached_kernel is None or self.cached_kernel.shape[-1] ! patch_tokens.shape[-1]: kernel self.gaussian_kernel_1d(patch_tokens.shape[-1], self.sigma) self.cached_kernel kernel.view(1, 1, -1).to(patch_tokens.device) # FFT-based filtering spectrum torch.fft.fft(patch_tokens, dim-1) spectrum torch.fft.fftshift(spectrum, dim-1) spectrum spectrum * self.cached_kernel spectrum torch.fft.ifftshift(spectrum, dim-1) filtered torch.fft.ifft(spectrum, dim-1).real return filtered.to(dtypeoriginal_dtype) def stability_score(self, original: torch.Tensor, filtered: torch.Tensor, eps: float 1e-6) - torch.Tensor: Compute token stability scores. Higher score more stable token (less affected by high-frequency removal) Formula from the paper: score original / |filtered - original| diff filtered - original return original / (torch.abs(diff) eps) def forward_features(self, x: torch.Tensor): Forward pass with token selection. Args: x: Input images [B, 3, H, W] Returns: logits: Classification logits [B, num_classes] cls_token: The aggregated CLS token [B, hidden_dim] (optional) # Standard ViT preprocessing and encoding x self._process_input(x) n x.shape[0] # Add class token batch_class_token self.class_token.expand(n, -1, -1) x torch.cat([batch_class_token, x], dim1) x self.encoder(x) # Patch tokens only (drop CLS token) patch_tokens x[:, 1:] # [B, N, D] # Apply low-pass filtering filtered_tokens self.low_pass_filter(patch_tokens) # Compute stability scores scores self.stability_score(patch_tokens, filtered_tokens) # Select top-k most stable tokens top_k min(self.top_k, patch_tokens.shape[1]) _, indices torch.topk(scores, ktop_k, dim1, largestTrue) # Gather selected tokens selected_tokens torch.gather(patch_tokens, 1, indices) # [B, k, D] # Average to form new CLS token cls_token torch.mean(selected_tokens, dim1) # [B, D] return cls_token, patch_tokens def forward(self, x: torch.Tensor): cls_token, _ self.forward_features(x) # Classification head logits self.heads(cls_token) return logits, cls_token def create_last_vit( pretrained_path: str None, top_k: int 1, num_classes: int 1000, device: str cuda if torch.cuda.is_available() else cpu ) - LASTViT: Create a LAST-ViT model with optional pretrained weights. Args: pretrained_path: Path to pretrained checkpoint (from GitHub releases) top_k: Number of tokens to select (k1 for standard LAST-ViT) num_classes: Number of output classes device: Device to load model on Returns: LASTViT model model LASTViT( image_size224, patch_size16, num_layers12, num_heads12, hidden_dim768, mlp_dim3072, num_classesnum_classes, top_ktop_k ) if pretrained_path: checkpoint torch.load(pretrained_path, map_locationcpu) # Handle different checkpoint formats if isinstance(checkpoint, dict): state_dict checkpoint.get(model, checkpoint) else: state_dict checkpoint # Remove common prefixes new_state_dict {} for key, value in state_dict.items(): new_key key for prefix in [module., model.]: if new_key.startswith(prefix): new_key new_key[len(prefix):] new_state_dict[new_key] value # Load weights missing, unexpected model.load_state_dict(new_state_dict, strictFalse) print(fLoaded pretrained weights from {pretrained_path}) if missing: print(f Missing keys: {missing[:5]}... if len(missing) 5 else f Missing keys: {missing}) if unexpected: print(f Unexpected keys: {unexpected[:5]}... if len( unexpected) 5 else f Unexpected keys: {unexpected}) model model.to(device) return model if __name__ __main__: # Test the model print(LAST-ViT Model Test) device cuda if torch.cuda.is_available() else cpu print(fDevice: {device}) # Create model model create_last_vit(ViT_190k.pth, top_k1) model.eval() print(fModel created with {sum(p.numel() for p in model.parameters()):,} parameters) # Test forward pass dummy_input torch.randn(2, 3, 224, 224).to(device) with torch.no_grad(): logits, cls_token model(dummy_input) features model.forward_features(dummy_input) print(fInput shape: {dummy_input.shape}) print(f Logits shape: {logits.shape}) print(f CLS token shape: {cls_token.shape}) print(fExpected: logits [2, 1000], cls_token [2, 768]) if isinstance(features, tuple): cls_feat, patch_feat features print(f\nFeature extraction:) print(f CLS feature: {cls_feat.shape}) print(f Patch feature: {patch_feat.shape}) print(\n✓ Model works correctly!)这里将特征提取与分类头进行了分离这部分说不定能作为特征提取。可下载标签监督的权重效果可视化这张图直观地展示了LaSt-ViT的效果。可以看到右边的可视化结果中ConvNet、ViT、LaSt-ViT高投票区域非常精准地覆盖了食物和动物这些前景物体。而标准 ViT 的 Patch Score 往往分散在复杂的背景区域中容易受噪声干扰而 LaSt-ViT 的投票区域则能准确聚焦于目标主体显著提升了识别的鲁棒性。优点总结LazyStrike完全不需要额外标注不改动ViT原有架构只在预训练阶段生效全监督、文本监督、自监督三种范式全部通用而且推理阶段没有任何额外开销真正做到了一次修改所有任务都能稳定涨点。实验结果12个基准全提效LaSt-ViT在三种主流训练范式全监督、文本监督、自监督、多种主流模型DeiT, ViT, CLIP, DINO以及12个不同的下游任务基准上进行了全面验证取得了显著且一致的性能提升。在全监督ViT上粗分割任务VOC12数据集从22.3%暴涨到32.8%相当于凭空多出了物体定位的能力在文本监督CLIP上语义分割VOC20从49.0直接冲到了75.0从普通能用的水平变成了SOTA级别在自监督DINO上不用Register也能实现无监督物体发现性能也能达到67.6。普通的ViT与加入了LazyStrike之后的ViT的PCA 可视化对比前则不能很好区分前景和背景而后者不仅能区分前景和背景还能 区分物体的各个部位。空间模式可视化 Visualize which token positions are most frequently selected by the LAST-ViT frequency-domain token selection mechanism across different k values. Use your own image folder instead of ImageNet. import os from collections import defaultdict import torch import torch.nn as nn import numpy as np import matplotlib matplotlib.use(Agg) # 保存图片用避免 TkAgg 报错 import matplotlib.pyplot as plt from PIL import Image from scipy import ndimage from torchvision.models.vision_transformer import VisionTransformer from torchvision.transforms import transforms as T from torch.utils.data import DataLoader, Dataset from tqdm import tqdm class CustomImageDataset(Dataset): Dataset for users own image folder. def __init__(self, root_dir, transformNone, max_samplesNone): self.root_dir root_dir self.transform transform self.image_paths [] valid_exts (.jpg, .jpeg, .png, .bmp, .webp) for dirpath, _, filenames in os.walk(root_dir): for name in filenames: if name.lower().endswith(valid_exts): self.image_paths.append(os.path.join(dirpath, name)) self.image_paths sorted(self.image_paths) if max_samples is not None: self.image_paths self.image_paths[:max_samples] if len(self.image_paths) 0: raise RuntimeError(fNo images found in: {root_dir}) print(fFound {len(self.image_paths)} images in: {root_dir}) def __len__(self): return len(self.image_paths) def __getitem__(self, idx): path self.image_paths[idx] img Image.open(path).convert(RGB) if self.transform is not None: img self.transform(img) label 0 return img, label, path class DenseViTWithTracking(VisionTransformer): ViT with token selection tracking for visualization. def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.cached_kernel None self.token_selection_counts defaultdict(lambda: defaultdict(int)) self.enable_tracking False self.k_values [1, 5, 10, 20] def gaussian_kernel_1d(self, kernel_size, sigma): device self.class_token.device x torch.arange( -kernel_size // 2 1, kernel_size // 2 1, devicedevice, dtypetorch.float32, ) kernel torch.exp(-0.5 * (x / sigma) ** 2) kernel kernel / torch.max(kernel) return kernel def compute_token_scores(self, images): Return: diff: [B, N], token frequency-domain difference score x_detach: [B, N, D], patch token embeddings x self._process_input(images) n x.shape[0] batch_class_token self.class_token.expand(n, -1, -1) x torch.cat([batch_class_token, x], dim1) x self.encoder(x) x_detach x[:, 1:] # [B, 196, 768] hidden_dim x_detach.shape[-1] if self.cached_kernel is None or self.cached_kernel.shape[-1] ! hidden_dim: self.cached_kernel ( self.gaussian_kernel_1d(hidden_dim, hidden_dim ** 0.5) .to(x.device) .unsqueeze(0) .unsqueeze(0) ) x_fft torch.fft.fft(x_detach, dim-1) x_fft torch.fft.fftshift(x_fft, dim-1) x_fft x_fft * self.cached_kernel.to(x.device) x_fft torch.fft.ifftshift(x_fft, dim-1) x_filtered torch.fft.ifft(x_fft, dim-1).real diff torch.norm(x_detach - x_filtered, dim-1) # [B, N] return diff, x_detach def forward(self, x: torch.Tensor): diff, x_detach self.compute_token_scores(x) if self.enable_tracking: for k in self.k_values: if k diff.shape[1]: _, indices torch.topk(diff, kk, dim1, largestTrue) for b in range(indices.shape[0]): selected_tokens indices[b].detach().cpu().numpy() for token_idx in selected_tokens: self.token_selection_counts[k][int(token_idx)] 1 _, indices torch.topk(diff, k1, dim1, largestTrue) selected_tokens torch.gather( x_detach, dim1, indexindices.unsqueeze(-1).expand(-1, -1, x_detach.shape[-1]), ) cls_token torch.mean(selected_tokens, dim1) return cls_token, None def get_topk_mask_for_image(self, image_tensor, k, device): Get current images selected token mask. image_tensor: [3, 224, 224] self.eval() with torch.no_grad(): if image_tensor.dim() 3: image_tensor image_tensor.unsqueeze(0) image_tensor image_tensor.to(device) diff, _ self.compute_token_scores(image_tensor) k min(k, diff.shape[1]) _, indices torch.topk(diff, kk, dim1, largestTrue) mask np.zeros(diff.shape[1], dtypenp.float32) for token_idx in indices[0].detach().cpu().numpy(): mask[int(token_idx)] 1.0 return mask def load_model_and_data(data_root, num_samples1000, batch_size32, checkpoint_pathNone): Load model and users custom image data. model DenseViTWithTracking( image_size224, patch_size16, num_layers12, num_heads12, hidden_dim768, mlp_dim3072, ) model.eval() if checkpoint_path and os.path.exists(checkpoint_path): print(fLoading pretrained weights: {checkpoint_path}) checkpoint torch.load(checkpoint_path, map_locationcpu) state_dict checkpoint if isinstance(checkpoint, dict): if model in checkpoint: state_dict checkpoint[model] elif state_dict in checkpoint: state_dict checkpoint[state_dict] new_state_dict {} for key, value in state_dict.items(): if key.startswith(model.): new_state_dict[key[6:]] value elif key.startswith(module.): new_state_dict[key[7:]] value else: new_state_dict[key] value try: model.load_state_dict(new_state_dict, strictTrue) print(Pretrained weights loaded successfully.) except Exception as e: print(fStrict loading failed: {e}) print(Trying partial loading...) model_dict model.state_dict() matched_dict { k: v for k, v in new_state_dict.items() if k in model_dict and model_dict[k].shape v.shape } model_dict.update(matched_dict) model.load_state_dict(model_dict) print(fPartial load succeeded: {len(matched_dict)} / {len(model_dict)} parameters matched.) else: raise FileNotFoundError( fCheckpoint not found: {checkpoint_path}\n f请确认 --checkpoint 路径正确。为了避免噪声这里不再使用随机权重。 ) transform T.Compose( [ T.Resize(256), T.CenterCrop(224), T.ToTensor(), T.Normalize( mean(0.485, 0.456, 0.406), std(0.229, 0.224, 0.225), ), ] ) dataset CustomImageDataset( root_dirdata_root, transformtransform, max_samplesnum_samples, ) dataloader DataLoader( dataset, batch_sizebatch_size, shuffleFalse, num_workers0, pin_memoryTrue, ) return model, dataloader def visualize_token_selection(token_counts_dict, num_tokens196, save_pathtoken_selection_heatmap.png): Visualize global token selection patterns for different k values. k_values sorted(token_counts_dict.keys()) num_k len(k_values) if num_k 0: print(No token selection counts found.) return cols min(3, num_k) rows (num_k cols - 1) // cols fig, axes plt.subplots(rows, cols, figsize(5 * cols, 5 * rows)) if num_k 1: axes [axes] else: axes np.array(axes).reshape(-1) grid_size int(np.sqrt(num_tokens)) for idx, k in enumerate(k_values): counts token_counts_dict[k] token_freq np.zeros(num_tokens, dtypenp.float32) for token_idx, count in counts.items(): if 0 token_idx num_tokens: token_freq[token_idx] count if token_freq.max() 0: token_freq_normalized token_freq / token_freq.max() else: token_freq_normalized token_freq token_grid token_freq_normalized.reshape(grid_size, grid_size) im axes[idx].imshow( token_grid, cmaphot, interpolationnearest, vmin0, vmax1, ) axes[idx].set_title( fk{k} Token Selection Frequency\nTotal selections: {int(token_freq.sum())}, fontsize12, fontweightbold, ) axes[idx].set_xlabel(Patch Column) axes[idx].set_ylabel(Patch Row) plt.colorbar(im, axaxes[idx], labelNormalized count) for idx in range(num_k, len(axes)): axes[idx].axis(off) plt.tight_layout() plt.savefig(save_path, dpi300, bbox_inchestight) plt.close() print(fGlobal heatmap saved to: {save_path}) stats_path save_path.replace(.png, _stats.png) fig, ax plt.subplots(figsize(14, 7)) top_n 30 for k in k_values: counts token_counts_dict[k] sorted_tokens sorted(counts.items(), keylambda x: x[1], reverseTrue)[:top_n] if len(sorted_tokens) 0: continue token_indices [t[0] for t in sorted_tokens] token_counts_vals [t[1] for t in sorted_tokens] ax.plot( token_indices, token_counts_vals, markero, labelfk{k}, linewidth2, markersize5, ) ax.set_xlabel(Token Index) ax.set_ylabel(Selection Count) ax.set_title(fTop {top_n} Most Selected Tokens) ax.legend() ax.grid(True, alpha0.3) plt.tight_layout() plt.savefig(stats_path, dpi300, bbox_inchestight) plt.close() print(fStatistics plot saved to: {stats_path}) def denormalize_image(tensor, mean(0.485, 0.456, 0.406), std(0.229, 0.224, 0.225)): mean torch.tensor(mean).view(3, 1, 1) std torch.tensor(std).view(3, 1, 1) return tensor * std mean def visualize_mask_on_image(image_tensor, mask, image_size224): Overlay selected patch mask on original image. img denormalize_image(image_tensor.clone()) img img.clamp(0, 1) img_np img.permute(1, 2, 0).cpu().numpy() grid_size int(np.sqrt(mask.size)) mask_2d mask.reshape(grid_size, grid_size) mask_img np.zeros((image_size, image_size), dtypebool) patch_h image_size // grid_size patch_w image_size // grid_size for i in range(grid_size): for j in range(grid_size): if mask_2d[i, j] 0: h_start i * patch_h h_end min((i 1) * patch_h, image_size) w_start j * patch_w w_end min((j 1) * patch_w, image_size) mask_img[h_start:h_end, w_start:w_end] True result img_np.copy() mask_3d mask_img[:, :, np.newaxis] red np.array([1.0, 0.25, 0.25]) fill_alpha 0.45 result result * (1 - fill_alpha * mask_3d) red * fill_alpha * mask_3d edges ndimage.sobel(mask_img.astype(float)) edge_mask np.abs(edges) 0.1 result[edge_mask, 0] 1.0 result[edge_mask, 1] 0.0 result[edge_mask, 2] 0.0 return np.clip(result, 0, 1) def save_sample_visualizations(model, images, paths, output_dir, device, start_index, max_save10): Save per-image token masks using current image top-k tokens. saved 0 num_tokens (224 // 16) ** 2 for img_idx in range(images.shape[0]): if start_index saved max_save: break sample_img images[img_idx].cpu() image_path paths[img_idx] fig, axes plt.subplots( 1, len(model.k_values), figsize(4 * len(model.k_values), 4), ) if len(model.k_values) 1: axes [axes] for idx, k in enumerate(model.k_values): mask model.get_topk_mask_for_image(sample_img, k, device) assert mask.size num_tokens img_with_mask visualize_mask_on_image(sample_img, mask) axes[idx].imshow(img_with_mask) axes[idx].set_title(fk{k}, fontsize12, fontweightbold) axes[idx].axis(off) basename os.path.splitext(os.path.basename(image_path))[0] save_name fsample_{start_index saved 1:03d}_{basename}.png save_path os.path.join(output_dir, save_name) plt.tight_layout() plt.savefig(save_path, dpi150, bbox_inchestight) plt.close() print(fSaved sample mask: {save_path}) saved 1 return saved def main(): import argparse parser argparse.ArgumentParser(descriptionVisualize LAST-ViT token selection patterns on custom images) parser.add_argument( --data-root, typestr, defaultrE:\PythonProject\YoloProject\data\test_coco8\images, helpPath to your own image folder, ) parser.add_argument( --num-samples, typeint, default100, helpNumber of images to use, ) parser.add_argument( --batch-size, typeint, default8, helpBatch size, ) parser.add_argument( --device, typestr, defaultcuda if torch.cuda.is_available() else cpu, helpDevice: cuda or cpu, ) parser.add_argument( --output-dir, typestr, default./visualize, helpOutput directory, ) parser.add_argument( --checkpoint, typestr, defaultrE:\PythonProject\LAST_ViT\ViT_190k.pth, helpPath to pretrained weights, ) parser.add_argument( --max-sample-vis, typeint, default10, helpNumber of individual images to visualize, ) args parser.parse_args() os.makedirs(args.output_dir, exist_okTrue) print( * 70) print(LAST-ViT Token Selection Visualization) print( * 70) print(fData root: {args.data_root}) print(fCheckpoint: {args.checkpoint}) print(fDevice: {args.device}) print(fNum samples: {args.num_samples}) print(fBatch size: {args.batch_size}) print(fOutput dir: {args.output_dir}) print( * 70) print(\nLoading model and custom data...) model, dataloader load_model_and_data( data_rootargs.data_root, num_samplesargs.num_samples, batch_sizeargs.batch_size, checkpoint_pathargs.checkpoint, ) model model.to(args.device) model.eval() model.enable_tracking True num_tokens (224 // 16) ** 2 print(f\nModel loaded.) print(fPatch tokens: {num_tokens}) print(fk values: {model.k_values}) print(\nRunning inference and tracking token selections...) saved_sample_count 0 with torch.no_grad(): for batch_idx, batch in enumerate(tqdm(dataloader, descProcessing)): images, labels, paths batch images images.to(args.device, non_blockingTrue) try: _ model(images) if saved_sample_count args.max_sample_vis: saved_now save_sample_visualizations( modelmodel, imagesimages, pathspaths, output_dirargs.output_dir, deviceargs.device, start_indexsaved_sample_count, max_saveargs.max_sample_vis, ) saved_sample_count saved_now except Exception as e: print(fError processing batch {batch_idx}: {e}) continue print(\nToken Selection Statistics:) print(- * 70) for k in sorted(model.token_selection_counts.keys()): counts model.token_selection_counts[k] total_selections sum(counts.values()) unique_tokens len(counts) print(fk{k:2d}: total selections{total_selections:8d}, unique tokens{unique_tokens:3d}) top_5 sorted(counts.items(), keylambda x: x[1], reverseTrue)[:5] print(f top 5 tokens: {top_5}) print(\nGenerating global statistics visualization...) save_path os.path.join(args.output_dir, token_selection_heatmap.png) visualize_token_selection( model.token_selection_counts, num_tokensnum_tokens, save_pathsave_path, ) print(\n * 70) print(Done.) print( * 70) if __name__ __main__: main()总结这项研究通过系统性分析深刻揭示了ViT因“惰性聚合”而依赖背景补丁作为语义捷径的核心成因并提出了一种名为LaSt-ViT的创新方法。该方法通过频率感知的选择性聚合策略从根源上消除了ViT的特征伪影最终实现了在多种训练范式和下游任务上的一致性性能提升。

相关文章:

LaSt-ViT:Vision Transformers Need More Than Registers(CVPR 2026)

前言 尽管 Vision Transformers (ViTs) 在图像分类等领域取得了巨大成功,但其内部机制仍存在诸多未解之谜。近年来的研究发现,在需要密集特征的下游任务中,ViTs 表现出多种令人困惑的伪影 (Artifacts),这些问题普遍存在于不同的训…...

CLeVeR:用多模态对比学习把“漏洞语义”从代码里挖出来

“现有自动化漏洞检测模型往往学习的是「整体功函数语义」,这会带入与漏洞无关的噪声,影响检测效果。CLeVeR提出用对比学习(contrastive learning)在代码与漏洞描述之间建立语义对齐,并通过Adapter、Representation Re…...

nstagram内容分级扩展后跨境品牌如何把握素材边界

数字围栏:内容分级时代,跨境品牌的素材合规之道当全球社交平台纷纷筑起内容分级的数字围栏,一场关于品牌表达边界的静默革命正在发生。对于跨境品牌而言,这不再仅仅是文化适配的课题,更是如何在日益复杂的数字监管环境…...

别再手写Word表格了!用poi-tl 1.12.0 + SpringBoot 3分钟搞定动态数据填充

3分钟极速上手:用poi-tl在SpringBoot中玩转Word表格动态填充 每次接到"导出Word报表"的需求就头皮发麻?还在用Apache POI逐行拼接表格单元格?上周团队新来的实习生花了整整两天调试一个动态表格导出功能,结果生成的文档…...

Taotoken的API Key管理与审计日志功能保障企业调用安全

Taotoken的API Key管理与审计日志功能保障企业调用安全 1. 企业级API Key管理 在Taotoken平台上,企业管理员可以创建多个API Key,并为每个Key分配不同的权限和使用限制。这一功能特别适合需要将大模型能力集成到多个项目或分配给不同团队的企业用户。 …...

对比直接使用厂商 API 通过聚合平台管理多模型成本更透明

通过聚合平台管理多模型成本更透明 1. 多厂商 API 的成本管理痛点 在同时使用多个大模型厂商的 API 时,成本管理往往面临诸多挑战。每个厂商都有独立的计费体系、账单周期和用量统计方式,导致开发者需要登录不同平台查看分散的数据。这种碎片化的管理方…...

数学老师都在用的GeoGebra 6,从下载到上手画图,10分钟搞定动态几何

GeoGebra 6:数学课堂的动态教学神器,10分钟从零到精彩演示 当抛物线在屏幕上随着参数的调整而优雅地舞动,当几何图形在拖动中展现出不变的性质,数学的魅力就这样直观地呈现在学生眼前。GeoGebra 6正是这样一款能让数学课堂活起来…...

别再死磕nmtui了!虚拟机里Linux网卡激活失败的3个真实原因与终极解法

虚拟机环境下Linux网卡激活失败的深度诊断与实战解决方案 当你第5次在虚拟机里敲下nmtui命令,屏幕依然弹出那个令人窒息的"Activation failed"错误时,该意识到问题可能远超出配置文件本身。作为常年与虚拟化环境打交道的技术顾问,我…...

Tidyverse 2.0报告自动化终极面试清单(23道题|11道代码实操|9道架构设计),仅剩最后200份PDF版解析可领

更多请点击: https://intelliparadigm.com 第一章:Tidyverse 2.0报告自动化核心演进与面试全景图 Tidyverse 2.0标志着R语言数据科学生态的一次结构性升级,其核心不再仅聚焦于语法一致性,而是深度整合报告生成、动态渲染与可复现…...

终极LaTeX公式转换指南:3秒将网页公式完美粘贴到Word

终极LaTeX公式转换指南:3秒将网页公式完美粘贴到Word 【免费下载链接】LaTeX2Word-Equation Copy LaTeX Equations as Word Equations, a Chrome Extension 项目地址: https://gitcode.com/gh_mirrors/la/LaTeX2Word-Equation 还在为学术论文写作时公式复制格…...

别再死记硬背Payload了!用DVWA靶场手把手教你理解SQL注入与XSS的底层原理

从DVWA靶场实战拆解Web安全核心原理&#xff1a;SQL注入与XSS的攻防博弈 当你第一次在DVWA靶场中输入admin or 11成功登录时&#xff0c;是否思考过为什么这个简单的字符串能绕过密码验证&#xff1f;当<img srcx onerroralert(1)>在页面上弹出警告框时&#xff0c;浏览器…...

三电平半桥LLC谐振变换器电路仿真研究:移相角度控制与DSP PWM生成方式探讨,输出电压优化...

三电平半桥LLC谐振变换器电路仿真 采用频率控制方式 引入一定的移相角度&#xff08;比较小&#xff09; 驱动信号采用CMPA CMPB方式产生 增计数模式&#xff08;参照DSP PWM生成&#xff09; 相比普通半桥LLC开关管电压应力小 输出电压闭环控制 输出特性好&#xff0c;几乎无超…...

Firefox老版本爱好者的自救指南:手动修改prefs.js与channel-prefs.js锁定版本

Firefox版本锁定终极指南&#xff1a;从配置文件到注册表的深度控制 你是否也遇到过这样的困扰&#xff1f;精心挑选的Firefox旧版本在不知不觉中被强制升级&#xff0c;熟悉的界面突然变得陌生&#xff0c;那些陪伴多年的插件一夜之间全部失效。对于依赖特定版本进行开发测试的…...

论mysql国盾shell-sfa犯罪行为集团下的分项工程及反向注入原理尐深度纳米算法下的鐌檵鄐鉎行为

SQL注入核心技术原理及纳米技术深度计算机算法机器应用函数技术的黑客用途是什么涵盖与控制原理**1. 概念澄清&#xff1a;不存在“纳米技术深度计算机算法”** * **SQL 注入**是一种针对**数据库软件层面**的网络攻击技术&#xff0c;利用的是代码逻辑漏洞。 * **纳米技术…...

VR视频转换终极指南:用VR-Reversal将3D视频智能转换为2D格式

VR视频转换终极指南&#xff1a;用VR-Reversal将3D视频智能转换为2D格式 【免费下载链接】VR-reversal VR-Reversal - Player for conversion of 3D video to 2D with optional saving of head tracking data and rendering out of 2D copies. 项目地址: https://gitcode.com…...

关于Vscode配置企业Git

1.获取账号信息①企业邮箱&#xff1a;xxxxxxxxxxx.com.cn②在邮箱里会有企业给你的git密码修改自己设置③打开Vscode下方终端旁边有一个加号&#xff0c;新建终端2.配置终端打开 VS Code&#xff0c;在顶部菜单栏点击 终端(Terminal) -> 新建终端(New Terminal)&#xff0c…...

思源宋体TTF版本兼容性与升级指南

思源宋体TTF版本兼容性与升级指南 【免费下载链接】source-han-serif-ttf Source Han Serif TTF 项目地址: https://gitcode.com/gh_mirrors/so/source-han-serif-ttf 版本兼容性矩阵 版本发布日期主要特性兼容性说明升级建议v1.0012021-10-15初始版本发布完全兼容所有…...

【2024信创落地硬核案例】:某政务终端从ARM切换至平头哥曳影1520,C驱动重写仅用11人日——附完整Makefile与Kconfig补丁包

更多请点击&#xff1a; https://intelliparadigm.com 第一章&#xff1a;国产化 RISC-V 芯片 C 语言驱动适配案例 随着平头哥、芯来科技、赛昉科技等厂商推出成熟 RISC-V SoC&#xff08;如 TH1520、Nuclei N/NX 系列、JH7110&#xff09;&#xff0c;国产嵌入式生态正加速构…...

为什么你的Tidyverse 2.0报告总在CI/CD中断?8大环境变量冲突真相,含可复用的docker-compose.yml模板

更多请点击&#xff1a; https://intelliparadigm.com 第一章&#xff1a;Tidyverse 2.0自动化数据报告的核心挑战与定位 Tidyverse 2.0 的发布标志着 R 生态在声明式数据处理与可重复报告生成方面迈入新阶段&#xff0c;但其自动化能力在真实生产环境中仍面临多重结构性挑战。…...

别再被线阻坑了!用开尔文四线法精准测量毫欧级电阻(附Multisim仿真步骤)

毫欧级电阻测量的终极方案&#xff1a;开尔文四线法全解析与Multisim实战 在硬件调试的微观世界里&#xff0c;毫欧级电阻的测量就像用普通尺子测量头发丝的直径——传统两线法的误差足以淹没真实信号。当某次电源模块异常发热的排查中&#xff0c;我反复测量MOSFET的导通电阻始…...

别急着把 autocast 全切成 bf16:RTX 3090 上把 GEMM、Conv2d 和 ResNet18 训练都跑完后,我的推荐顺序是这样

别急着把 autocast 全切成 bf16:RTX 3090 上把 GEMM、Conv2d 和 ResNet18 训练都跑完后,我的推荐顺序是这样 很多人把 bf16 当成“更稳的 fp16”,也有人一提消费级显卡就先下结论:bf16 肯定更慢,别折腾。我这次在一张 RTX 3090 上,把 4096x4096 的 GEMM、Conv2d 和 ResN…...

VSCode 2026协作权限体系曝光:细粒度文件级/行级/语义级锁定策略(含RBAC+SCIM集成方案)

更多请点击&#xff1a; https://intelliparadigm.com 第一章&#xff1a;VSCode 2026实时协作多人编辑的架构演进与设计哲学 VSCode 2026 将协作能力从插件生态升维至核心运行时层&#xff0c;其底层采用基于 CRDT&#xff08;Conflict-free Replicated Data Type&#xff09…...

Microsemi Libero SoC 实战:用Verilog写个LED呼吸灯,从仿真到上板全流程(附ModelSim波形分析)

Microsemi Libero SoC实战&#xff1a;Verilog实现LED呼吸灯的全流程解析 引言 呼吸灯效果在消费电子产品中极为常见&#xff0c;从笔记本电脑的睡眠指示灯到智能家居设备的待机状态提示&#xff0c;这种柔和的光线渐变效果远比简单的闪烁更富科技感和用户体验。对于FPGA开发…...

如何在 Chrome 浏览器中快速接入 Taotoken 并调用大模型 API

如何在 Chrome 浏览器中快速接入 Taotoken 并调用大模型 API 1. 准备工作 在开始之前&#xff0c;请确保您已经拥有 Taotoken 平台的 API Key。登录 Taotoken 控制台&#xff0c;在「API 密钥」页面可以创建和管理您的密钥。同时&#xff0c;建议在「模型广场」查看当前可用的…...

【紧急预警】大模型上线前必做的3项R统计审查:Feldman–Hajek偏差指数、Wasserstein公平距离、Bootstrap置信带校验

更多请点击&#xff1a; https://intelliparadigm.com 第一章&#xff1a;R语言在大语言模型偏见检测中的统计方法导论 在大语言模型&#xff08;LLM&#xff09;部署日益广泛的背景下&#xff0c;系统性偏见可能通过训练数据、词嵌入或生成逻辑被隐式放大。R语言凭借其强大的…...

Visual C++运行库终极修复指南:一键解决系统依赖问题的完整教程

Visual C运行库终极修复指南&#xff1a;一键解决系统依赖问题的完整教程 【免费下载链接】vcredist AIO Repack for latest Microsoft Visual C Redistributable Runtimes 项目地址: https://gitcode.com/gh_mirrors/vc/vcredist Visual C运行库是Windows系统中不可或缺…...

终极指南:让Mem Reduct内存优化工具显示中文界面的完整方案

终极指南&#xff1a;让Mem Reduct内存优化工具显示中文界面的完整方案 【免费下载链接】memreduct Lightweight real-time memory management application to monitor and clean system memory on your computer. 项目地址: https://gitcode.com/gh_mirrors/me/memreduct …...

告别视频消失焦虑:用m4s-converter永久保存你的B站收藏

告别视频消失焦虑&#xff1a;用m4s-converter永久保存你的B站收藏 【免费下载链接】m4s-converter 一个跨平台小工具&#xff0c;将bilibili缓存的m4s格式音视频文件合并成mp4 项目地址: https://gitcode.com/gh_mirrors/m4/m4s-converter 你是否曾经遇到过这样的情况&…...

用MSP432P401R驱动HX711压力传感器:从引脚配置到数据读取的保姆级代码解析

MSP432P401R与HX711压力传感器的深度开发指南 1. 硬件架构与通信原理 HX711是一款专为高精度称重传感器设计的24位模数转换器芯片&#xff0c;采用双线制串行通信协议。与MSP432P401R微控制器的配合使用&#xff0c;能够构建高性价比的称重系统解决方案。 核心引脚功能&#xf…...

java同步另一项目数据

java同步另一平台的数据 在 Java 中实现跨平台的数据同步&#xff0c;并没有唯一的标准答案&#xff0c;而是需要根据你的数据量大小、实时性要求以及对方平台提供的接口类型来选择合适的方案。 结合你的 Spring Boot MyBatis-Plus 技术栈&#xff0c;这里为你梳理了 4 种最主…...