CLIP 自定义数据集微调指南
本文档介绍如何使用自己的数据集微调 CLIP 模型,实现图像-文本对的语义匹配任务。
目录: 基本要求
完整流程
常见问题
参考资料
基本要求
环境配置
CUDA :11.8
cuDNN :8.9.3.28
Python :3.9.21
PyTorch :2.6.0
兼容性验证 已测试可行环境: RTX 5090 + CUDA 12.8 + Python 3.12.0 + PyTorch 2.7.0
安装依赖 1 2 3 4 5 6 git clone https://github.com/mlfoundations/open_clip/ cd open_clippip install -r requirements.txt
完整流程 数据准备 Step1:原始数据 将分类数据按以下目录结构存放:
1 2 3 4 5 6 7 8 9 Task/ # 任务根目录 ├── cls_1/ # 类别1(类别名即文件夹名) │ ├── img_001.jpg │ ├── img_002.jpg │ └── ... └── cls_2/ # 类别2 ├── img_001.jpg └── ...
注意:!!!文件夹名称即为类别名,将用于生成文本标签
Step2:划分训练集和验证集 使用cls_split_data.py脚本, 脚本将数据按比例划分为 train、val 集:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 """ 不均衡数据划分为train、val、test, 图片存储格式 Task/cls_name/image.jpg """ from pathlib import Pathimport shutilimport randomimport numpy as npdef cls_split_data (src_dir, output_root, train_ratio, val_ratio ): image_exts = {".jpg" , ".jpeg" , ".png" , ".bmp" , ".tif" , ".tiff" } for disease_dir in src_dir.rglob("*" ): if disease_dir.is_dir(): all_imgs = [p for p in disease_dir.glob("*" ) if p.suffix.lower() in image_exts] if not all_imgs: continue random.shuffle(all_imgs) n_total = len (all_imgs) n_train = int (np.floor(train_ratio * n_total)) n_val = int (np.ceil(val_ratio * n_total)) train_imgs = all_imgs[:n_train] val_imgs = all_imgs[n_train:n_train + n_val] test_imgs = all_imgs[n_train + n_val:] rel_path = disease_dir.relative_to(src_dir) for split_name, split_imgs in [("train" , train_imgs), ("val" , val_imgs)]: for img in split_imgs: target_path = output_root / split_name / rel_path / img.name target_path.parent.mkdir(parents=True , exist_ok=True ) shutil.copy(img, target_path) print ("按层级划分完成。" ) if __name__ == "__main__" : src_dir = Path("/xxx/open_clip/datasets/Task" ) output_root = Path("/xxx/open_clip/datasets/Task_split" ) train_ratio, val_ratio = 0.8 , 0.2 cls_split_data(src_dir, output_root, train_ratio, val_ratio)
执行脚本 :
1 python cls_split_data.py
输出结果 :
1 2 3 4 5 6 7 8 9 Task_split/ # 输出根目录 ├── train/ # 训练集 (80%) │ ├── cls_1/ │ │ └── img_001.jpg │ └── cls_2/ │ └── img_001.jpg └── val/ # 验证集 (20%) ├── cls_1/ └── cls_2/
Step 3: (可选) 合并已有数据集 如果需要合并多个数据集:
1 2 mkdir target_folder cp -r folder1/* folder2/* target_folder/
Step 4: 生成 CSV 训练文件 CLIP 模型使用 CSV 格式存储图像-文本对。运行 file2caption.py 脚本生成训练所需的 CSV 文件。
file2caption.py脚本:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 """ # 1、生成训练用的csv # 2、生成所有caption列表保存到json文件下 # 3、得到caption和对应的病害中文名 """ from pathlib import Pathimport randomimport pandas as pdimport csvimport jsonfrom tqdm import tqdmfrom collections import defaultdictdef format_semantics (root_folder, train_csv_file, val_csv_file, cls_text_json, train_rate=0.8 ): """ :purpose: generate clip image-text semantice pair and save csv :param root_folder: dataset root path :param train_csv_file: train csv save path :param val_csv_file: val csv save path :return: csv """ image_files = [] image_extensions = {".jpg" , ".jpeg" , ".png" , ".gif" , ".bmp" , ".tiff" } for file in root_folder.rglob('*' ): if file.is_file() and (file.suffix.lower() in image_extensions): subfolder_name = file.parent.name image_files.append((file, subfolder_name)) random.shuffle(image_files) subfolder_dict = defaultdict(list ) for file, subfolder_name in image_files: subfolder_dict[subfolder_name].append(file) train_files = [] val_files = [] captions_list = [] for subfolder_name, files in subfolder_dict.items(): random.shuffle(files) num_train = int (train_rate * len (files)) train_files.extend(files[:num_train]) val_files.extend(files[num_train:]) train_data = [] val_data = [] for file in train_files: new_name = f"{file.parent} /{file.name} " caption = f"A photo of {file.parent.name} " train_data.append([new_name, caption]) if caption not in captions_list: captions_list.append(caption) for file in val_files: new_name = f"{file.parent} /{file.name} " caption = f"A photo of a {file.parent.name} " val_data.append([new_name, caption]) if caption not in captions_list: captions_list.append(caption) random.shuffle(train_data) random.shuffle(val_data) with open (train_csv_file, mode='w' , newline='' ) as file: writer = csv.writer(file) writer.writerow(['Image' , 'Caption' ]) writer.writerows(train_data) with open (val_csv_file, mode='w' , newline='' ) as file: writer = csv.writer(file) writer.writerow(['Image' , 'Caption' ]) writer.writerows(val_data) with open (cls_text_json, 'w' , encoding='utf-8' ) as f: json.dump(captions_list, f) print (f"训练集已保存到 {train_csv_file} " ) print (f"验证集已保存到 {val_csv_file} " ) if __name__ == "__main__" : root_folder = Path('train_data/xxx' ) train_csv_file = 'train_data/xxx.csv' val_csv_file = 'train_data/tmp.csv' cls_text_json = 'train_data/cls_text.json' train_rate = 1 format_semantics(root_folder, train_csv_file, val_csv_file, cls_text_json, train_rate)
执行命令 :
1 python utils/file2caption.py
生成的CSV格式示例 :
1 2 3 4 Image,Caption datasets/xxx/DJI_20250707144724_0003_35.jpg,A photo of xxx. datasets/xxx/org_a9e2f7585c07a6f5_1752478690000.jpg,A photo of xxx. datasets/xxx/139608.jpg",Not a photo of xxx.
CSV 文件说明 :
第一行:列标识(Image, Caption),无需修改
后续行:每行包含图片相对路径和对应的文本标签
Caption 格式可自定义(如 “A photo of {class_name}”)
模型训练: Step 5: 执行训练 使用 train.sh 脚本启动训练:
训练脚本 (train.sh):
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 #!/bin/bash export CUDA_VISIBLE_DEVICES=1nohup torchrun --nproc_per_node 1 -m src.open_clip_train.main \ --batch-size 256 \ --precision amp \ --workers 4 \ --save-frequency 10 \ --dataset-type csv \ --csv-separator="," \ --train-data datasets/Task_split/train.csv \ --val-data datasets/Task_split/val.csv \ --val-frequency 5 \ --csv-img-key Image \ --csv-caption-key Caption \ --warmup 1000 \ --lr=5e-6 \ --wd=0.1 \ --epochs=100 \ --model ViT-B-32 \ --pretrained weights/clip_model/open_clip_pytorch_model.bin \ --grad-checkpointing \ --device "cuda" > nohup.log 2>&1 &
关键参数说明
参数
说明
推荐值
--train-data
训练集CSV路径
Step 4生成的train.csv
--val-data
验证集CSV路径
Step 4生成的val.csv
--csv-img-key
CSV中图片列名
Image
--csv-caption-key
CSV中文本列名
Caption
--batch-size
批次大小
256 (根据显存调整)
--lr
学习率
5e-6 (微调推荐)
--epochs
训练轮数
100
--model
模型架构
ViT-B-32 / ViT-L-14
--pretrained
预训练权重路径
本地模型路径
--precision
混合精度训练
amp
--grad-checkpointing
梯度检查点
启用可节省显存
启动训练
训练日志将保存到 nohup.log 文件中。
模型测试 Step 6: 推理测试 使用训练好的模型进行推理测试:
测试脚本 (inference.py):
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 import torchfrom PIL import Imageimport open_clipimport timemodel, _, preprocess = open_clip.create_model_and_transforms( 'ViT-B-32' , pretrained='/path/to/your/finetuned_model.bin' ) model.eval () tokenizer = open_clip.get_tokenizer('ViT-B-32' ) image = preprocess(Image.open ("test_img/sample.jpg" )).unsqueeze(0 ) text_labels = [ "A photo of cls_1" , "A photo of cls_2" , "A photo of cls_3" ] text = tokenizer(text_labels) start = time.time() with torch.no_grad(), torch.cuda.amp.autocast(): image_features = model.encode_image(image) text_features = model.encode_text(text) image_features /= image_features.norm(dim=-1 , keepdim=True ) text_features /= text_features.norm(dim=-1 , keepdim=True ) text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1 ) end = time.time() print (f"推理时间: {end - start:.4 f} s" )print (f"类别概率: {text_probs} " )print (f"\n预测结果:" )for label, prob in zip (text_labels, text_probs[0 ]): print (f" {label} : {prob.item():.4 f} " )
运行推理
输出示例 :
1 2 3 4 5 6 7 推理时间: 0.0234s 类别概率: tensor([[0.8520, 0.1230, 0.0250]]) 预测结果: A photo of cls_1: 0.8520 A photo of cls_2: 0.1230 A photo of cls_3: 0.0250
常见问题 1. 显存不足 (OOM) 解决方案 :
减小 --batch-size (如 256 → 128 → 64)
启用 --grad-checkpointing
使用更小的模型 (ViT-B-32 代替 ViT-L-14)
2. 训练速度慢 优化建议 :
增加 --workers 数量(数据加载线程)
使用 --precision amp 混合精度训练
3. 模型不收敛 调试步骤 :
检查学习率是否过大/过小 (推荐 5e-6 ~ 1e-5)
增加 --warmup 步数
确认 CSV 文件格式正确
检查数据集质量和标签准确性
4. Caption 设计建议 推荐格式 :
1 2 3 4 5 6 7 8 "A photo of {class_name}" "This is a photo of {class_name}, which is a type of {category}" "Not a photo of {class_name}"
5. 如何选择预训练模型
模型
参数量
性能
显存需求
适用场景
ViT-B-32
151M
中
~12GB
快速实验
ViT-B-16
149M
较高
~16GB
平衡性能和速度
ViT-L-14
427M
高
~24GB
追求最佳性能
参考资料:
版本历史