这是一个用于服装分类的深度学习项目,使用 PyTorch 实现。
clothing-classification/
│
├── dataset.py # 数据集加载和处理
├── device.py # 设备选择
├── main.py # 主程序入口
├── map.py # 标签映射
├── model.py # 模型定义
├── test.py # 测试模型
├── train.py # 训练模型
└── README.md # 项目说明
- Python 3.11+
- PyTorch 2.5.1+
- torchvision
- pandas
- Pillow
可以使用以下命令安装依赖:
pip install -r requirements.txt本项目使用了一个服装分类数据集:Clothing dataset (full, high resolution)
请将数据集放置在 ./clothing-dataset-full 目录下,数据集应包含以下文件:
images.csv:包含图像文件名和标签的 CSV 文件images_original/:包含图像文件的目录
可以使用以下命令训练模型:
python main.py训练过程中会定期保存 checkpoint 文件 checkpoint.pth.tar,以便在中断后继续训练。
训练完成后,模型会自动进行测试,并输出测试集的准确率。
训练完成后,模型会自动保存到当前目录,文件名格式为 model_YYYYMMDDHHMMSS.pth。
如果训练过程中断,可以通过加载 checkpoint 文件继续训练。checkpoint 文件默认保存路径为 checkpoint.pth.tar。
如果有可用的 CUDA 设备,项目将自动使用 CUDA 加速。
如果 CUDA 不可用,可能是安装了 CPU 版本的 Pytorch。请考虑按照 指南 重新安装可用于 CUDA 加速的 PyTorch。
如果有多个 GPU,将自动使用 MPS(Multi-Process Service)来提高性能。
前往 dataset.py 文件,修改 num_workers、prefech_factor 变量来优化数据集加载性能。具体见代码注释。
下载模型文件:model
前往 test.py 文件,修改 model_path 变量为模型文件路径,然后运行以下命令:
python test.py