# Teaching example: image classification with a pre-trained Vision Transformer (ViT).
#
# The ViT-base model (google/vit-base-patch16-224) splits each 224x224 input image
# into 16x16 patches (14x14 = 196 tokens) via a simple convolutional projector into
# a 768-dimensional embedding space, then classifies with the attention stack.
# We first run the pre-trained model as-is (1000-class ImageNet head) on PASCAL VOC
# 2012 images, then fine-tune a fresh 20-class head for VOC multi-label classification.
#
# Public/educational code, generated with DeepSeek assistance and verified on the
# SCNet platform.  Original write-up:
# https://blog.csdn.net/YucongCai/article/details/159696147

# ------------------------------
# 1. Download / prepare Pascal VOC 2012 and import the required libraries
# ------------------------------
import os
import random
import shutil

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
from torchvision.datasets import VOCDetection
from PIL import Image
from transformers import ViTImageProcessor, ViTForImageClassification, TrainingArguments, Trainer
from sklearn.metrics import accuracy_score

data_dir = "./data"
os.makedirs(data_dir, exist_ok=True)

# Download the dataset (this may take a while)
voc_train = VOCDetection(
    root=data_dir,
    year="2012",
    image_set="trainval",  # use train+validation for a larger training set
    download=True,
    transform=None,
)

# Save a few random images for later visualisation
sample_dir = "./sample_images"
os.makedirs(sample_dir, exist_ok=True)

random_indices = random.sample(range(len(voc_train)), 2)
for i, idx in enumerate(random_indices):
    img, _ = voc_train[idx]
    img.save(os.path.join(sample_dir, f"sample_{i}.png"))

print(f"Dataset ready. Total training images: {len(voc_train)}")

# NOTE: the image transforms below are applied on the fly (per sample, while
# training).  This adds some CPU overhead, but it is the standard approach for
# large datasets: training mostly runs on the GPU, so spare CPU capacity does
# the preprocessing in parallel, saving large amounts of time and memory.
# ------------------------------
# 2. Build a dataset for multi-label classification
# ------------------------------
# The 20 Pascal VOC classes (in the order used by the dataset)
VOC_CLASSES = [
    "aeroplane", "bicycle", "bird", "boat", "bottle",
    "bus", "car", "cat", "chair", "cow",
    "diningtable", "dog", "horse", "motorbike", "person",
    "pottedplant", "sheep", "sofa", "train", "tvmonitor",
]


class VOCClassificationDataset(Dataset):
    """Wrap VOCDetection and turn its XML annotations into multi-hot label vectors."""

    def __init__(self, root, split="trainval", transform=None):
        self.dataset = VOCDetection(
            root=root,
            year="2012",
            image_set=split,
            download=False,   # data was already downloaded in block 1
            transform=None,   # transform is applied manually in __getitem__
        )
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        img, annotation = self.dataset[idx]

        # Build a multi-hot label vector over the 20 VOC classes
        label_vector = torch.zeros(len(VOC_CLASSES), dtype=torch.float32)
        objects = annotation["annotation"]["object"]
        if isinstance(objects, dict):
            # a single annotated object is not wrapped in a list by the XML parser
            objects = [objects]
        for obj in objects:
            class_name = obj["name"]
            if class_name in VOC_CLASSES:
                class_idx = VOC_CLASSES.index(class_name)
                label_vector[class_idx] = 1.0

        if self.transform:
            img = self.transform(img)
        return img, label_vector


# Define preprocessing for ViT (resize + ImageNet normalisation)
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

train_dataset = VOCClassificationDataset(root=data_dir, split="trainval", transform=transform)
val_dataset = VOCClassificationDataset(root=data_dir, split="val", transform=transform)

print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")

# Section 3 preview: the pre-trained ViT-base model was trained on the 1000
# ImageNet classes and ships with its own id2label index-to-text mapping
# (model.config.id2label[idx_class]), so it can be used for prediction directly.
# Even before fine-tuning it produces usable multi-label-style predictions,
# although it tends to label finer-grained details than a typical CNN would,
# alongside the main result.

# ------------------------------
# Helper: denormalisation (used for display only)
# ------------------------------
def denormalize(tensor, mean, std):
    """Reverse the normalisation to get pixel values in [0, 1]."""
    tensor = tensor.clone()
    for t, m, s in zip(tensor, mean, std):
        t.mul_(s).add_(m)
    return tensor.clamp_(0, 1)
# ------------------------------
# 3. Load ViT-Base with ORIGINAL ImageNet head, predict, THEN replace head
# ------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Load the pre-trained model WITHOUT changing the head (keeps 1000-class ImageNet head)
model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224")
# It carries the default id2label mapping for the 1000 ImageNet classes
model.to(device)
model.eval()

# Define the image processor
processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")

# Pick two random images from the training set
random_indices = random.sample(range(len(train_dataset)), 2)

# ----- PART A: Predict using the original ImageNet head -----
for i, idx in enumerate(random_indices):
    img_tensor, label_vector = train_dataset[idx]  # already normalised (ImageNet stats)

    # Denormalise for correct visualisation
    img_denorm = denormalize(img_tensor,
                             mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    img_pil = transforms.ToPILImage()(img_denorm)

    # Prepare input for the model (the processor applies its own normalisation)
    inputs = processor(images=img_pil, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits  # shape: [1, 1000] (ImageNet logits)
    probs = torch.nn.functional.softmax(logits, dim=-1).cpu().numpy()[0]

    # Show the image
    plt.figure(figsize=(6, 6))
    plt.imshow(img_pil)
    plt.title(f"Original ViT (ImageNet) prediction for sample {i+1}")
    plt.axis("off")
    plt.show()

    # Get top-5 ImageNet predictions
    top5_idx = np.argsort(probs)[-5:][::-1]
    top5_labels = [model.config.id2label[idx] for idx in top5_idx]
    top5_probs = [probs[idx] for idx in top5_idx]

    # Ground-truth VOC classes for comparison
    true_voc_classes = [VOC_CLASSES[j] for j, val in enumerate(label_vector.numpy()) if val == 1]

    print(f"Sample {i+1}:")
    print(f"  True VOC classes: {true_voc_classes if true_voc_classes else 'none'}")
    print("  Top-5 ImageNet predictions:")
    for label, prob in zip(top5_labels, top5_probs):
        print(f"    {label}: {prob:.3f}")
    print("-" * 50)
# Next, point the ViT classifier head at the PASCAL VOC label set and fine-tune.

# ----- PART B: Replace the classifier head with a randomly initialised 20-class head -----
print("Replacing the classifier head with a randomly initialised 20-class head for Pascal VOC...")
model.classifier = nn.Linear(model.config.hidden_size, len(VOC_CLASSES))
# The new head is now randomly initialised (PyTorch default)
model.to(device)
# The model is now ready for fine-tuning on Pascal VOC (done below in block 4)

# ------------------------------
# 4. Fine-tuning the model
# ------------------------------
# DataLoaders
batch_size = 16
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

# Optimiser and loss (BCE-with-logits matches the multi-hot targets)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
criterion = nn.BCEWithLogitsLoss()

num_epochs = 3
train_losses = []
val_losses = []

for epoch in range(num_epochs):
    # --- training ---
    model.train()
    total_loss = 0.0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(pixel_values=images).logits
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # accumulate sum of per-sample losses for a dataset-wide average
        total_loss += loss.item() * images.size(0)

    epoch_train_loss = total_loss / len(train_loader.dataset)
    train_losses.append(epoch_train_loss)

    # --- validation ---
    model.eval()
    total_val_loss = 0.0
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(pixel_values=images).logits
            loss = criterion(outputs, labels)
            total_val_loss += loss.item() * images.size(0)

            preds = torch.sigmoid(outputs).cpu().numpy()
            all_preds.extend(preds)
            all_labels.extend(labels.cpu().numpy())

    epoch_val_loss = total_val_loss / len(val_loader.dataset)
    val_losses.append(epoch_val_loss)

    # Compute threshold-based accuracy: binarise sigmoid outputs at 0.5 and
    # measure element-wise agreement with the multi-hot ground truth
    all_preds = np.array(all_preds) > 0.5
    all_labels = np.array(all_labels)
    accuracy = (all_preds == all_labels).mean()

    print(f"Epoch {epoch+1}/{num_epochs}")
    print(f"  Train loss: {epoch_train_loss:.4f} | Val loss: {epoch_val_loss:.4f}")
    print(f"  Val multi-label accuracy: {accuracy:.4f}\n")
# ------------------------------
# 5. Save, reload, and test the fine-tuned model
# ------------------------------
# Save the fine-tuned model
model_save_path = "./vit_pascal_voc_finetuned.pth"
torch.save(model.state_dict(), model_save_path)
print(f"Model saved to {model_save_path}")

# Re-load the model: rebuild the same architecture (base + 20-class head),
# then load the fine-tuned weights into it
model_reloaded = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224")
model_reloaded.classifier = nn.Linear(model_reloaded.config.hidden_size, len(VOC_CLASSES))
model_reloaded.load_state_dict(torch.load(model_save_path, map_location=device))
model_reloaded.to(device)
model_reloaded.eval()


# Denormalisation function (already defined in block 2, repeated here for clarity)
def denormalize(tensor, mean, std):
    """Reverse the normalisation to get pixel values in [0, 1]."""
    tensor = tensor.clone()
    for t, m, s in zip(tensor, mean, std):
        t.mul_(s).add_(m)
    return tensor.clamp_(0, 1)


# Test on the same two images as before (random_indices from block 3)
for i, idx in enumerate(random_indices):
    img_tensor, label_vector = train_dataset[idx]  # already normalised

    # Denormalise to [0, 1] before converting to PIL
    img_denorm = denormalize(img_tensor,
                             mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    img_pil = transforms.ToPILImage()(img_denorm)

    # The processor applies its own correct normalisation
    inputs = processor(images=img_pil, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model_reloaded(**inputs)
    logits = outputs.logits
    predictions = torch.sigmoid(logits).cpu().numpy()[0]

    plt.figure(figsize=(6, 6))
    plt.imshow(img_pil)
    plt.title(f"Fine-tuned prediction for sample {i+1}")
    plt.axis("off")
    plt.show()

    predicted_classes = [VOC_CLASSES[j] for j, prob in enumerate(predictions) if prob > 0.5]
    true_classes = [VOC_CLASSES[j] for j, val in enumerate(label_vector.numpy()) if val == 1]

    print(f"Sample {i+1} (after fine-tuning):")
    print(f"  True classes: {true_classes if true_classes else 'none'}")
    print(f"  Predicted classes: {predicted_classes if predicted_classes else 'none'}\n")

# Contact — job offers / HR / project collaboration: yucongcai_business@outlook.com
# Research-related enquiries: yucongcai_research@outlook.com