Skip to content

Commit 3b6fed1

Browse files
committed
abcnetv2 train
1 parent bf41194 commit 3b6fed1

13 files changed

+188
-24
lines changed

projects/ABCNet/abcnet/model/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from .bifpn import BiFPN
1313
from .coordinate_head import CoordinateHead
1414
from .rec_roi_head import RecRoIHead
15+
from .task_utils import * # noqa: F401,F403
1516

1617
__all__ = [
1718
'ABCNetDetHead', 'ABCNetDetPostprocessor', 'ABCNetRecBackbone',

projects/ABCNet/abcnet/model/abcnet_det_head.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -181,8 +181,8 @@ def forward_single(self, x, scale, stride):
181181
# float to avoid overflow when enabling FP16
182182
if self.use_scale:
183183
bbox_pred = scale(bbox_pred).float()
184-
else:
185-
bbox_pred = bbox_pred.float()
184+
# else:
185+
# bbox_pred = bbox_pred.float()
186186
if self.norm_on_bbox:
187187
# bbox_pred needed for gradient computation has been modified
188188
# by F.relu(bbox_pred) when run with PyTorch 1.10. So replace

projects/ABCNet/abcnet/model/abcnet_det_postprocessor.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -216,8 +216,8 @@ def __call__(self, pred_results, data_samples, training: bool = False):
216216
Returns:
217217
list[TextDetDataSample]: Batch of post-processed datasamples.
218218
"""
219-
if training:
220-
return data_samples
219+
# if training:
220+
# return data_samples
221221
cfg = self.train_cfg if training else self.test_cfg
222222
if cfg is None:
223223
cfg = {}

projects/ABCNet/abcnet/model/bezier_roi_extractor.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ def forward(self, feats: Tuple[Tensor],
8888
# convert fp32 to fp16 when amp is on
8989
rois = rois.type_as(feats[0])
9090
out_size = self.roi_layers[0].output_size
91-
feats = feats[:3]
91+
# feats = feats[:3]
9292
num_levels = len(feats)
9393
roi_feats = feats[0].new_zeros(
9494
rois.size(0), self.out_channels, *out_size)

projects/ABCNet/abcnet/model/bifpn.py

+9-6
Original file line numberDiff line numberDiff line change
@@ -170,10 +170,12 @@ def __init__(self,
170170
self.bifpn_convs = nn.ModuleList()
171171
# weighted
172172
self.weight_two_nodes = nn.Parameter(
173-
torch.Tensor(2, levels).fill_(init))
173+
torch.Tensor(2, levels).fill_(init), requires_grad=True)
174+
174175
self.weight_three_nodes = nn.Parameter(
175-
torch.Tensor(3, levels - 2).fill_(init))
176-
self.relu = nn.ReLU()
176+
torch.Tensor(3, levels - 2).fill_(init), requires_grad=True)
177+
178+
# self.relu = nn.ReLU(inplace=False)
177179
for _ in range(2):
178180
for _ in range(self.levels - 1): # 1,2,3
179181
fpn_conv = nn.Sequential(
@@ -193,9 +195,10 @@ def forward(self, inputs):
193195
# build top-down and down-top path with stack
194196
levels = self.levels
195197
# w relu
196-
w1 = self.relu(self.weight_two_nodes)
197-
w1 /= torch.sum(w1, dim=0) + self.eps # normalize
198-
w2 = self.relu(self.weight_three_nodes)
198+
199+
_w1 = F.relu(self.weight_two_nodes)
200+
w1 = _w1 / (torch.sum(_w1, dim=0) + self.eps) # normalize
201+
w2 = F.relu(self.weight_three_nodes)
199202
# w2 /= torch.sum(w2, dim=0) + self.eps # normalize
200203
# build top-down
201204
idx_bifpn = 0

projects/ABCNet/abcnet/model/rec_roi_head.py

+41-7
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
2-
from typing import Tuple
2+
from typing import Optional, Sequence, Tuple
33

44
from mmengine.structures import LabelData
55
from torch import Tensor
@@ -15,14 +15,19 @@ class RecRoIHead(BaseRoIHead):
1515
"""Simplest base roi head including one bbox head and one mask head."""
1616

1717
def __init__(self,
18-
neck=None,
18+
inputs_indices: Optional[Sequence] = None,
19+
neck: OptMultiConfig = None,
20+
assigner: OptMultiConfig = None,
1921
sampler: OptMultiConfig = None,
2022
roi_extractor: OptMultiConfig = None,
2123
rec_head: OptMultiConfig = None,
2224
init_cfg=None):
2325
super().__init__(init_cfg)
24-
if sampler is not None:
25-
self.sampler = TASK_UTILS.build(sampler)
26+
self.inputs_indices = inputs_indices
27+
self.assigner = assigner
28+
if assigner is not None:
29+
self.assigner = TASK_UTILS.build(assigner)
30+
self.sampler = TASK_UTILS.build(sampler)
2631
if neck is not None:
2732
self.neck = MODELS.build(neck)
2833
self.roi_extractor = MODELS.build(roi_extractor)
@@ -43,11 +48,39 @@ def loss(self, inputs: Tuple[Tensor], data_samples: DetSampleList) -> dict:
4348
Returns:
4449
dict[str, Tensor]: A dictionary of loss components
4550
"""
46-
proposals = [
47-
ds.gt_instances[~ds.gt_instances.ignored] for ds in data_samples
48-
]
51+
52+
if self.inputs_indices is not None:
53+
inputs = [inputs[i] for i in self.inputs_indices]
54+
# proposals = [
55+
# ds.gt_instances[~ds.gt_instances.ignored] for ds in data_samples
56+
# ]
57+
proposals = list()
58+
for ds in data_samples:
59+
pred_instances = ds.pred_instances
60+
gt_instances = ds.gt_instances
61+
# # assign
62+
# gt_beziers = gt_instances.beziers
63+
# pred_beziers = pred_instances.beziers
64+
# assign_index = [
65+
# int(
66+
# torch.argmin(
67+
# torch.abs(gt_beziers - pred_beziers[i]).sum(dim=1)))
68+
# for i in range(len(pred_beziers))
69+
# ]
70+
# proposal = InstanceData()
71+
# proposal.texts = gt_instances.texts + gt_instances[
72+
# assign_index].texts
73+
# proposal.beziers = torch.cat(
74+
# [gt_instances.beziers, pred_instances.beziers], dim=0)
75+
if self.assigner:
76+
gt_instances, pred_instances = self.assigner.assign(
77+
gt_instances, pred_instances)
78+
proposal = self.sampler.sample(gt_instances, pred_instances)
79+
proposals.append(proposal)
4980

5081
proposals = [p for p in proposals if len(p) > 0]
82+
if hasattr(self, 'neck') and self.neck is not None:
83+
inputs = self.neck(inputs)
5184
bbox_feats = self.roi_extractor(inputs, proposals)
5285
rec_data_samples = [
5386
TextRecogDataSample(gt_text=LabelData(item=text))
@@ -57,6 +90,7 @@ def loss(self, inputs: Tuple[Tensor], data_samples: DetSampleList) -> dict:
5790

5891
def predict(self, inputs: Tuple[Tensor],
5992
data_samples: DetSampleList) -> RecSampleList:
93+
inputs = inputs[:3]
6094
if hasattr(self, 'neck') and self.neck is not None:
6195
inputs = self.neck(inputs)
6296
pred_instances = [ds.pred_instances for ds in data_samples]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
# Copyright (c) OpenMMLab. All rights reserved.
2+
from .assigner import L1DistanceAssigner
3+
from .sampler import ConcatSampler, OnlyGTSampler
4+
5+
__all__ = ['L1DistanceAssigner', 'ConcatSampler', 'OnlyGTSampler']
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# Copyright (c) OpenMMLab. All rights reserved.
2+
import torch
3+
4+
from mmocr.registry import TASK_UTILS
5+
6+
7+
@TASK_UTILS.register_module()
8+
class L1DistanceAssigner:
9+
10+
def assign(self, gt_instances, pred_instances):
11+
gt_beziers = gt_instances.beziers
12+
pred_beziers = pred_instances.beziers
13+
assign_index = [
14+
int(
15+
torch.argmin(
16+
torch.abs(gt_beziers - pred_beziers[i]).sum(dim=1)))
17+
for i in range(len(pred_beziers))
18+
]
19+
pred_instances.assign_index = assign_index
20+
return gt_instances, pred_instances
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
# Copyright (c) OpenMMLab. All rights reserved.
2+
import torch
3+
from mmengine.structures import InstanceData
4+
5+
from mmocr.registry import TASK_UTILS
6+
7+
8+
@TASK_UTILS.register_module()
9+
class ConcatSampler:
10+
11+
def sample(self, gt_instances, pred_instances):
12+
if len(pred_instances) == 0:
13+
return gt_instances
14+
proposals = InstanceData()
15+
proposals.texts = gt_instances.texts + gt_instances[
16+
pred_instances.assign_index].texts
17+
proposals.beziers = torch.cat(
18+
[gt_instances.beziers, pred_instances.beziers], dim=0)
19+
return proposals
20+
21+
22+
@TASK_UTILS.register_module()
23+
class OnlyGTSampler:
24+
25+
def sample(self, gt_instances, pred_instances):
26+
return gt_instances[~gt_instances.ignored]

projects/ABCNet/config/_base_/default_runtime.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
default_scope = 'mmocr'
22
env_cfg = dict(
3-
cudnn_benchmark=True,
3+
cudnn_benchmark=False,
44
mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
55
dist_cfg=dict(backend='nccl'),
66
)

projects/ABCNet/config/_base_/schedules/schedule_sgd_500e.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
type='OptimWrapper',
44
optimizer=dict(type='SGD', lr=0.001, momentum=0.9, weight_decay=0.0001),
55
clip_grad=dict(type='value', clip_value=1))
6-
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=500, val_interval=20)
6+
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=500, val_interval=10)
77
val_cfg = dict(type='ValLoop')
88
test_cfg = dict(type='TestLoop')
99
# learning policy

projects/ABCNet/config/abcnet_v2/_base_abcnet-v2_resnet50_bifpn.py

+64-3
Original file line numberDiff line numberDiff line change
@@ -68,21 +68,46 @@
6868
std=0.01,
6969
bias=-4.59511985013459), # -log((1-p)/p) where p=0.01
7070
),
71-
module_loss=None,
71+
module_loss=dict(
72+
type='ABCNetDetModuleLoss',
73+
num_classes=num_classes,
74+
strides=strides,
75+
center_sampling=True,
76+
center_sample_radius=1.5,
77+
bbox_coder=bbox_coder,
78+
norm_on_bbox=norm_on_bbox,
79+
loss_cls=dict(
80+
type='mmdet.FocalLoss',
81+
use_sigmoid=use_sigmoid_cls,
82+
gamma=2.0,
83+
alpha=0.25,
84+
loss_weight=1.0),
85+
loss_bbox=dict(type='mmdet.GIoULoss', loss_weight=1.0),
86+
loss_centerness=dict(
87+
type='mmdet.CrossEntropyLoss',
88+
use_sigmoid=True,
89+
loss_weight=1.0)),
7290
postprocessor=dict(
7391
type='ABCNetDetPostprocessor',
7492
# rescale_fields=['polygons', 'bboxes'],
7593
use_sigmoid_cls=use_sigmoid_cls,
7694
strides=[8, 16, 32, 64, 128],
7795
bbox_coder=dict(type='mmdet.DistancePointBBoxCoder'),
7896
with_bezier=True,
97+
train_cfg=dict(
98+
# rescale_fields=['polygon', 'bboxes', 'bezier'],
99+
nms_pre=1000,
100+
nms=dict(type='nms', iou_threshold=0.4),
101+
score_thr=0.7),
79102
test_cfg=dict(
80103
# rescale_fields=['polygon', 'bboxes', 'bezier'],
81104
nms_pre=1000,
82105
nms=dict(type='nms', iou_threshold=0.4),
83-
score_thr=0.3))),
106+
score_thr=0.4))),
84107
roi_head=dict(
85108
type='RecRoIHead',
109+
assigner=dict(type='L1DistanceAssigner'),
110+
sampler=dict(type='ConcatSampler'),
86111
neck=dict(type='CoordinateHead'),
87112
roi_extractor=dict(
88113
type='BezierRoIExtractor',
@@ -97,7 +122,14 @@
97122
decoder=dict(
98123
type='ABCNetRecDecoder',
99124
dictionary=dictionary,
100-
postprocessor=dict(type='AttentionPostprocessor'),
125+
postprocessor=dict(
126+
type='AttentionPostprocessor',
127+
ignore_chars=['padding', 'unknown']),
128+
module_loss=dict(
129+
type='CEModuleLoss',
130+
ignore_first_char=False,
131+
ignore_char=-1,
132+
reduction='mean'),
101133
max_seq_len=25))),
102134
postprocessor=dict(
103135
type='ABCNetPostprocessor',
@@ -120,3 +152,32 @@
120152
type='PackTextDetInputs',
121153
meta_keys=('img_path', 'ori_shape', 'img_shape', 'scale_factor'))
122154
]
155+
156+
train_pipeline = [
157+
dict(
158+
type='LoadImageFromFile',
159+
file_client_args=file_client_args,
160+
color_type='color_ignore_orientation'),
161+
dict(
162+
type='LoadOCRAnnotations',
163+
with_polygon=True,
164+
with_bbox=True,
165+
with_label=True,
166+
with_text=True),
167+
dict(type='RemoveIgnored'),
168+
dict(type='RandomCrop', min_side_ratio=0.1),
169+
dict(
170+
type='RandomRotate',
171+
max_angle=30,
172+
pad_with_fixed_color=True,
173+
use_canvas=True),
174+
dict(
175+
type='RandomChoiceResize',
176+
scales=[(980, 2900), (1044, 2900), (1108, 2900), (1172, 2900),
177+
(1236, 2900), (1300, 2900), (1364, 2900), (1428, 2900),
178+
(1492, 2900)],
179+
keep_ratio=True),
180+
dict(
181+
type='PackTextDetInputs',
182+
meta_keys=('img_path', 'ori_shape', 'img_shape', 'scale_factor'))
183+
]

projects/ABCNet/config/abcnet_v2/abcnet-v2_resnet50_bifpn_500e_icdar2015.py

+15-1
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,21 @@
22
'_base_abcnet-v2_resnet50_bifpn.py',
33
'../_base_/datasets/icdar2015.py',
44
'../_base_/default_runtime.py',
5+
'../_base_/schedules/schedule_sgd_500e.py',
56
]
67

78
# dataset settings
9+
icdar2015_textspotting_train = _base_.icdar2015_textspotting_train
10+
icdar2015_textspotting_train.pipeline = _base_.train_pipeline
811
icdar2015_textspotting_test = _base_.icdar2015_textspotting_test
912
icdar2015_textspotting_test.pipeline = _base_.test_pipeline
1013

14+
train_dataloader = dict(
15+
batch_size=1,
16+
num_workers=8,
17+
persistent_workers=True,
18+
sampler=dict(type='DefaultSampler', shuffle=False),
19+
dataset=icdar2015_textspotting_train)
1120
val_dataloader = dict(
1221
batch_size=1,
1322
num_workers=4,
@@ -20,4 +29,9 @@
2029
val_cfg = dict(type='ValLoop')
2130
test_cfg = dict(type='TestLoop')
2231

23-
custom_imports = dict(imports=['abcnet'], allow_failed_imports=False)
32+
custom_imports = dict(
33+
imports=['projects.ABCNet.abcnet'], allow_failed_imports=False)
34+
35+
load_from = 'checkpoints/abcnet-v2_resnet50_bifpn_500e_pretrain.pth'
36+
37+
find_unused_parameters = True

0 commit comments

Comments
 (0)