Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 75 additions & 0 deletions DMF_setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
# -*- coding: utf-8 -*-
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
# !/usr/bin/env python

import glob
import os

import torch
from setuptools import find_packages
from setuptools import setup
from torch.utils.cpp_extension import CUDA_HOME
from torch.utils.cpp_extension import CppExtension
from torch.utils.cpp_extension import CUDAExtension

requirements = ["torch", "torchvision"]


def get_extensions():
this_dir = os.path.dirname(os.path.abspath(__file__))
extensions_dir = os.path.join(this_dir, "dconv", "csrc")

main_file = glob.glob(os.path.join(extensions_dir, "*.cpp"))
source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp"))
source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu"))

sources = main_file + source_cpu
extension = CppExtension

extra_compile_args = {"cxx": []}
define_macros = []

if (torch.cuda.is_available() and CUDA_HOME is not None) or os.getenv("FORCE_CUDA", "0") == "1":
extension = CUDAExtension
sources += source_cuda
define_macros += [("WITH_CUDA", None)]
extra_compile_args["nvcc"] = [
"-DCUDA_HAS_FP16=1",
"-D__CUDA_NO_HALF_OPERATORS__",
"-D__CUDA_NO_HALF_CONVERSIONS__",
"-D__CUDA_NO_HALF2_OPERATORS__",
]

sources = [os.path.join(extensions_dir, s) for s in sources]

include_dirs = [extensions_dir]

ext_modules = [
extension(
"dconv._C",
sources,
include_dirs=include_dirs,
define_macros=define_macros,
extra_compile_args=extra_compile_args,
)
]

return ext_modules


setup(
name="dconv",
version="0.1",
author="fmassa",
url="https://github.com/facebookresearch/maskrcnn-benchmark",
description="object detection in pytorch",
packages=find_packages(
exclude=(
"configs",
"tests",
)
),
# install_requires=requirements,
ext_modules=get_extensions(),
cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension},
)
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ Wenbin Li, Ziyi Wang, Xuesong Yang, Chuanqi Dong, Pinzhuo Tian, Tiexin Qin, Jing
+ [FEAT (CVPR 2020)](http://arxiv.org/abs/1812.03664)
+ [RENet (ICCV 2021)](https://arxiv.org/abs/2108.09666)
+ [FRN (CVPR 2021)](https://arxiv.org/abs/2012.01506)
+ [DMF (CVPR 2021)](https://arxiv.org/pdf/2103.13582)
+ [DeepBDC (CVPR 2022)](https://arxiv.org/abs/2204.04567)
+ [CPEA (ICCV 2023)](https://openaccess.thecvf.com/content/ICCV2023/papers/Hao_Class-Aware_Patch_Embedding_Adaptation_for_Few-Shot_Image_Classification_ICCV_2023_paper.pdf)

Expand Down
39 changes: 39 additions & 0 deletions config/DMF.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
includes:
- headers/data.yaml
- headers/device.yaml
- headers/misc.yaml
- headers/model.yaml
- headers/optimizer.yaml

way_num: 5
shot_num: 1
query_num: 6
episode_size: 8
train_episode: 2000
test_episode: 1200

device_ids: 0,1,2,3
n_gpu: 4
epoch: 120

optimizer:
name: SGD
kwargs:
lr: 0.05
momentum: 0.9
nesterov: true
weight_decay: 0.0005
other: null

backbone:
name: resnet12_drop
kwargs:
drop_block: true

classifier:
name: DMF
kwargs:
num_class: 64
nFeat: 640
kernel: 1
groups: 64
4 changes: 4 additions & 0 deletions config/backbones/resnet12_drop.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
backbone:
name: resnet12_drop
kwargs:
drop_block: true
7 changes: 7 additions & 0 deletions config/classifiers/DMF.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
classifier:
name: DMF
kwargs:
num_class: 64
nFeat: 64
kernel: 3
groups: 1
1 change: 1 addition & 0 deletions core/model/backbone/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from .swin_transformer import swin_s, swin_l, swin_b, swin_t, swin_mini
from .resnet_bdc import resnet12Bdc, resnet18Bdc
from core.model.backbone.utils.maml_module import convert_maml_module
from .resnet12_drop import resnet12_drop


def get_backbone(config):
Expand Down
Loading