
Commit 2fe3c41 (parent bf41194)

abcnetv2 train

14 files changed: +203 -30 lines

projects/ABCNet/abcnet/model/__init__.py  (+1)

@@ -12,6 +12,7 @@
 from .bifpn import BiFPN
 from .coordinate_head import CoordinateHead
 from .rec_roi_head import RecRoIHead
+from .task_utils import *  # noqa: F401,F403

 __all__ = [
     'ABCNetDetHead', 'ABCNetDetPostprocessor', 'ABCNetRecBackbone',

projects/ABCNet/abcnet/model/abcnet_det_head.py  (+3 -3)

@@ -14,7 +14,7 @@ class ABCNetDetHead(BaseTextDetHead):

     def __init__(self,
                  in_channels,
-                 module_loss=dict(type='ABCNetLoss'),
+                 module_loss=dict(type='ABCNetDetModuleLoss'),
                  postprocessor=dict(type='ABCNetDetPostprocessor'),
                  num_classes=1,
                  strides=(4, 8, 16, 32, 64),

@@ -181,8 +181,8 @@ def forward_single(self, x, scale, stride):
         # float to avoid overflow when enabling FP16
         if self.use_scale:
             bbox_pred = scale(bbox_pred).float()
-        else:
-            bbox_pred = bbox_pred.float()
+        # else:
+        #     bbox_pred = bbox_pred.float()
         if self.norm_on_bbox:
             # bbox_pred needed for gradient computation has been modified
             # by F.relu(bbox_pred) when run with PyTorch 1.10. So replace

projects/ABCNet/abcnet/model/abcnet_det_module_loss.py  (+13 -5)

@@ -2,6 +2,7 @@
 from typing import Dict, List, Tuple

 import torch
+import torch.nn.functional as F
 from mmdet.models.task_modules.prior_generators import MlvlPointGenerator
 from mmdet.models.utils import multi_apply
 from mmdet.utils import reduce_mean

@@ -149,11 +150,17 @@ def forward(self, inputs: Tuple[Tensor],
                 avg_factor=centerness_denorm)
             loss_centerness = self.loss_centerness(
                 pos_centerness, pos_centerness_targets, avg_factor=num_pos)
-            loss_bezier = self.loss_bezier(
-                pos_bezier_preds,
-                pos_bezier_targets,
-                weight=pos_centerness_targets[:, None],
-                avg_factor=centerness_denorm)
+            # loss_bezier = self.loss_bezier(
+            #     pos_bezier_preds,
+            #     pos_bezier_targets,
+            #     weight=pos_centerness_targets[:, None],
+            #     avg_factor=centerness_denorm)
+
+            loss_bezier = F.smooth_l1_loss(
+                pos_bezier_preds, pos_bezier_targets, reduction='none')
+            loss_bezier = (
+                (loss_bezier.mean(dim=-1) * pos_centerness_targets).sum() /
+                centerness_denorm)
         else:
             loss_bbox = pos_bbox_preds.sum()
             loss_centerness = pos_centerness.sum()

@@ -250,6 +257,7 @@ def _get_targets_single(self, data_sample: TextDetDataSample,
         polygons = gt_instances.polygons
         beziers = gt_bboxes.new([poly2bezier(poly) for poly in polygons])
         gt_instances.beziers = beziers
+        # beziers = gt_instances.beziers
        if num_gts == 0:
            return gt_labels.new_full((num_points,), self.num_classes), \
                gt_bboxes.new_zeros((num_points, 4)), \
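
The det module loss change above swaps the registered `loss_bezier` module for an inline smooth-L1 term weighted by centerness. A minimal standalone sketch of that arithmetic (shapes and the `centerness_denorm` value are illustrative assumptions; in the commit it is computed earlier in `forward()`):

    # Hedged sketch of the centerness-weighted smooth-L1 bezier loss above.
    # Assumed toy shapes: 5 positive points, 16 bezier regression values each.
    import torch
    import torch.nn.functional as F

    pos_bezier_preds = torch.randn(5, 16)
    pos_bezier_targets = torch.randn(5, 16)
    pos_centerness_targets = torch.rand(5)                 # weights in [0, 1]
    centerness_denorm = pos_centerness_targets.sum().clamp(min=1e-6)

    # per-element smooth L1, left unreduced so each point can be re-weighted
    loss = F.smooth_l1_loss(
        pos_bezier_preds, pos_bezier_targets, reduction='none')
    # mean over the 16 bezier values, weight by centerness, then normalize
    loss_bezier = (loss.mean(dim=-1) *
                   pos_centerness_targets).sum() / centerness_denorm
    print(loss_bezier)                                     # scalar tensor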

projects/ABCNet/abcnet/model/abcnet_det_postprocessor.py  (+2 -2)

@@ -216,8 +216,8 @@ def __call__(self, pred_results, data_samples, training: bool = False):
         Returns:
             list[TextDetDataSample]: Batch of post-processed datasamples.
         """
-        if training:
-            return data_samples
+        # if training:
+        #     return data_samples
         cfg = self.train_cfg if training else self.test_cfg
         if cfg is None:
             cfg = {}

projects/ABCNet/abcnet/model/bezier_roi_extractor.py  (+1 -1)

@@ -88,7 +88,7 @@ def forward(self, feats: Tuple[Tensor],
         # convert fp32 to fp16 when amp is on
         rois = rois.type_as(feats[0])
         out_size = self.roi_layers[0].output_size
-        feats = feats[:3]
+        # feats = feats[:3]
         num_levels = len(feats)
         roi_feats = feats[0].new_zeros(
             rois.size(0), self.out_channels, *out_size)

projects/ABCNet/abcnet/model/bifpn.py  (+9 -6)

@@ -170,10 +170,12 @@ def __init__(self,
         self.bifpn_convs = nn.ModuleList()
         # weighted
         self.weight_two_nodes = nn.Parameter(
-            torch.Tensor(2, levels).fill_(init))
+            torch.Tensor(2, levels).fill_(init), requires_grad=True)
+
         self.weight_three_nodes = nn.Parameter(
-            torch.Tensor(3, levels - 2).fill_(init))
-        self.relu = nn.ReLU()
+            torch.Tensor(3, levels - 2).fill_(init), requires_grad=True)
+
+        # self.relu = nn.ReLU(inplace=False)
         for _ in range(2):
             for _ in range(self.levels - 1):  # 1,2,3
                 fpn_conv = nn.Sequential(

@@ -193,9 +195,10 @@ def forward(self, inputs):
         # build top-down and down-top path with stack
         levels = self.levels
         # w relu
-        w1 = self.relu(self.weight_two_nodes)
-        w1 /= torch.sum(w1, dim=0) + self.eps  # normalize
-        w2 = self.relu(self.weight_three_nodes)
+
+        _w1 = F.relu(self.weight_two_nodes)
+        w1 = _w1 / (torch.sum(_w1, dim=0) + self.eps)  # normalize
+        w2 = F.relu(self.weight_three_nodes)
         # w2 /= torch.sum(w2, dim=0) + self.eps  # normalize
         # build top-down
         idx_bifpn = 0
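
The `forward()` edit in BiFPN replaces an in-place division of the ReLU output with an out-of-place one. A plausible reason (an assumption, not stated in the commit) is that ReLU saves its output for backward, so mutating that output in place breaks autograd; a standalone sketch with toy shapes:

    # Hedged sketch of the in-place vs. out-of-place weight normalization.
    import torch
    import torch.nn.functional as F

    eps = 0.0001
    w = torch.nn.Parameter(torch.full((2, 5), 0.5))        # toy fusion weights

    # old pattern: in-place division mutates the tensor ReLU saved for backward
    w_bad = F.relu(w)
    w_bad /= torch.sum(w_bad, dim=0) + eps
    # w_bad.sum().backward()  # RuntimeError: ... modified by an inplace operation

    # new pattern: out-of-place normalization keeps the graph intact
    _w1 = F.relu(w)
    w1 = _w1 / (torch.sum(_w1, dim=0) + eps)
    w1.sum().backward()
    print(w.grad.shape)                                    # torch.Size([2, 5])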

projects/ABCNet/abcnet/model/rec_roi_head.py  (+41 -7)

@@ -1,5 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from typing import Tuple
+from typing import Optional, Sequence, Tuple

 from mmengine.structures import LabelData
 from torch import Tensor

@@ -15,14 +15,19 @@ class RecRoIHead(BaseRoIHead):
     """Simplest base roi head including one bbox head and one mask head."""

     def __init__(self,
-                 neck=None,
+                 inputs_indices: Optional[Sequence] = None,
+                 neck: OptMultiConfig = None,
+                 assigner: OptMultiConfig = None,
                  sampler: OptMultiConfig = None,
                  roi_extractor: OptMultiConfig = None,
                  rec_head: OptMultiConfig = None,
                  init_cfg=None):
         super().__init__(init_cfg)
-        if sampler is not None:
-            self.sampler = TASK_UTILS.build(sampler)
+        self.inputs_indices = inputs_indices
+        self.assigner = assigner
+        if assigner is not None:
+            self.assigner = TASK_UTILS.build(assigner)
+        self.sampler = TASK_UTILS.build(sampler)
         if neck is not None:
             self.neck = MODELS.build(neck)
         self.roi_extractor = MODELS.build(roi_extractor)

@@ -43,11 +48,39 @@ def loss(self, inputs: Tuple[Tensor], data_samples: DetSampleList) -> dict:
         Returns:
             dict[str, Tensor]: A dictionary of loss components
         """
-        proposals = [
-            ds.gt_instances[~ds.gt_instances.ignored] for ds in data_samples
-        ]
+
+        if self.inputs_indices is not None:
+            inputs = [inputs[i] for i in self.inputs_indices]
+        # proposals = [
+        #     ds.gt_instances[~ds.gt_instances.ignored] for ds in data_samples
+        # ]
+        proposals = list()
+        for ds in data_samples:
+            pred_instances = ds.pred_instances
+            gt_instances = ds.gt_instances
+            # # assign
+            # gt_beziers = gt_instances.beziers
+            # pred_beziers = pred_instances.beziers
+            # assign_index = [
+            #     int(
+            #         torch.argmin(
+            #             torch.abs(gt_beziers - pred_beziers[i]).sum(dim=1)))
+            #     for i in range(len(pred_beziers))
+            # ]
+            # proposal = InstanceData()
+            # proposal.texts = gt_instances.texts + gt_instances[
+            #     assign_index].texts
+            # proposal.beziers = torch.cat(
+            #     [gt_instances.beziers, pred_instances.beziers], dim=0)
+            if self.assigner:
+                gt_instances, pred_instances = self.assigner.assign(
+                    gt_instances, pred_instances)
+            proposal = self.sampler.sample(gt_instances, pred_instances)
+            proposals.append(proposal)

         proposals = [p for p in proposals if len(p) > 0]
+        if hasattr(self, 'neck') and self.neck is not None:
+            inputs = self.neck(inputs)
         bbox_feats = self.roi_extractor(inputs, proposals)
         rec_data_samples = [
             TextRecogDataSample(gt_text=LabelData(item=text))

@@ -57,6 +90,7 @@ def predict(self, inputs: Tuple[Tensor],

     def predict(self, inputs: Tuple[Tensor],
                 data_samples: DetSampleList) -> RecSampleList:
+        inputs = inputs[:3]
         if hasattr(self, 'neck') and self.neck is not None:
             inputs = self.neck(inputs)
         pred_instances = [ds.pred_instances for ds in data_samples]
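
Stripped of the commented-out experiments, the new training-time flow in `RecRoIHead.loss()` reduces to: select the feature levels named by `inputs_indices`, match predictions to ground truth with the assigner, and let the sampler turn each image's instances into RoI proposals. A condensed sketch (the standalone function form is illustrative, not part of the commit):

    # Hedged sketch of the proposal-building loop that loss() now runs.
    def build_proposals(data_samples, assigner=None, sampler=None):
        proposals = []
        for ds in data_samples:
            gt_instances = ds.gt_instances
            pred_instances = ds.pred_instances
            if assigner is not None:
                # e.g. L1DistanceAssigner: tag each prediction with the index
                # of its nearest ground-truth bezier
                gt_instances, pred_instances = assigner.assign(
                    gt_instances, pred_instances)
            # e.g. ConcatSampler: keep GT plus the matched predictions
            proposals.append(sampler.sample(gt_instances, pred_instances))
        # drop images that ended up with no usable proposals
        return [p for p in proposals if len(p) > 0]
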
projects/ABCNet/abcnet/model/task_utils/__init__.py  (new file, +5)

@@ -0,0 +1,5 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .assigner import L1DistanceAssigner
+from .sampler import ConcatSampler, OnlyGTSampler
+
+__all__ = ['L1DistanceAssigner', 'ConcatSampler', 'OnlyGTSampler']

projects/ABCNet/abcnet/model/task_utils/assigner.py  (new file, +20)

@@ -0,0 +1,20 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+
+from mmocr.registry import TASK_UTILS
+
+
+@TASK_UTILS.register_module()
+class L1DistanceAssigner:
+
+    def assign(self, gt_instances, pred_instances):
+        gt_beziers = gt_instances.beziers
+        pred_beziers = pred_instances.beziers
+        assign_index = [
+            int(
+                torch.argmin(
+                    torch.abs(gt_beziers - pred_beziers[i]).sum(dim=1)))
+            for i in range(len(pred_beziers))
+        ]
+        pred_instances.assign_index = assign_index
+        return gt_instances, pred_instances

projects/ABCNet/abcnet/model/task_utils/sampler.py  (new file, +26)

@@ -0,0 +1,26 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+from mmengine.structures import InstanceData
+
+from mmocr.registry import TASK_UTILS
+
+
+@TASK_UTILS.register_module()
+class ConcatSampler:
+
+    def sample(self, gt_instances, pred_instances):
+        if len(pred_instances) == 0:
+            return gt_instances
+        proposals = InstanceData()
+        proposals.texts = gt_instances.texts + gt_instances[
+            pred_instances.assign_index].texts
+        proposals.beziers = torch.cat(
+            [gt_instances.beziers, pred_instances.beziers], dim=0)
+        return proposals
+
+
+@TASK_UTILS.register_module()
+class OnlyGTSampler:
+
+    def sample(self, gt_instances, pred_instances):
+        return gt_instances[~gt_instances.ignored]
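
To make the two new task utils concrete, here is a toy walk-through of the matching rule `L1DistanceAssigner` applies and the proposal set `ConcatSampler` builds from it, done with plain tensors and lists (the made-up numbers and the 16-value bezier layout are assumptions for illustration):

    # Hedged toy example of the assigner + sampler arithmetic, standalone.
    import torch

    gt_beziers = torch.tensor([[0.0] * 16,      # GT instance 0
                               [10.0] * 16])    # GT instance 1
    gt_texts = ['hello', 'world']
    pred_beziers = torch.tensor([[9.5] * 16,    # nearest to GT 1
                                 [0.2] * 16])   # nearest to GT 0

    # L1DistanceAssigner: nearest GT under summed absolute bezier distance
    assign_index = [
        int(torch.argmin(torch.abs(gt_beziers - pred_beziers[i]).sum(dim=1)))
        for i in range(len(pred_beziers))
    ]
    print(assign_index)                         # [1, 0]

    # ConcatSampler: GT proposals keep their texts, predicted proposals borrow
    # the text of their matched GT; bezier params of both are concatenated
    proposal_texts = gt_texts + [gt_texts[i] for i in assign_index]
    proposal_beziers = torch.cat([gt_beziers, pred_beziers], dim=0)
    print(proposal_texts)                       # ['hello', 'world', 'world', 'hello']
    print(proposal_beziers.shape)               # torch.Size([4, 16])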

projects/ABCNet/config/_base_/default_runtime.py  (+1 -1)

@@ -1,6 +1,6 @@
 default_scope = 'mmocr'
 env_cfg = dict(
-    cudnn_benchmark=True,
+    cudnn_benchmark=False,
     mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
     dist_cfg=dict(backend='nccl'),
 )

projects/ABCNet/config/_base_/schedules/schedule_sgd_500e.py  (+1 -1)

@@ -3,7 +3,7 @@
     type='OptimWrapper',
     optimizer=dict(type='SGD', lr=0.001, momentum=0.9, weight_decay=0.0001),
     clip_grad=dict(type='value', clip_value=1))
-train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=500, val_interval=20)
+train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=500, val_interval=10)
 val_cfg = dict(type='ValLoop')
 test_cfg = dict(type='TestLoop')
 # learning policy

projects/ABCNet/config/abcnet_v2/_base_abcnet-v2_resnet50_bifpn.py  (+65 -3)

@@ -68,21 +68,47 @@
                 std=0.01,
                 bias=-4.59511985013459),  # -log((1-p)/p) where p=0.01
         ),
-        module_loss=None,
+        module_loss=dict(
+            type='ABCNetDetModuleLoss',
+            num_classes=num_classes,
+            strides=strides,
+            center_sampling=True,
+            center_sample_radius=1.5,
+            bbox_coder=bbox_coder,
+            norm_on_bbox=norm_on_bbox,
+            loss_cls=dict(
+                type='mmdet.FocalLoss',
+                use_sigmoid=use_sigmoid_cls,
+                gamma=2.0,
+                alpha=0.25,
+                loss_weight=1.0),
+            loss_bbox=dict(type='mmdet.GIoULoss', loss_weight=1.0),
+            loss_centerness=dict(
+                type='mmdet.CrossEntropyLoss',
+                use_sigmoid=True,
+                loss_weight=1.0)),
         postprocessor=dict(
             type='ABCNetDetPostprocessor',
             # rescale_fields=['polygons', 'bboxes'],
             use_sigmoid_cls=use_sigmoid_cls,
             strides=[8, 16, 32, 64, 128],
             bbox_coder=dict(type='mmdet.DistancePointBBoxCoder'),
             with_bezier=True,
+            train_cfg=dict(
+                # rescale_fields=['polygon', 'bboxes', 'bezier'],
+                nms_pre=1000,
+                nms=dict(type='nms', iou_threshold=0.4),
+                score_thr=0.7),
             test_cfg=dict(
                 # rescale_fields=['polygon', 'bboxes', 'bezier'],
                 nms_pre=1000,
                 nms=dict(type='nms', iou_threshold=0.4),
-                score_thr=0.3))),
+                score_thr=0.4))),
     roi_head=dict(
         type='RecRoIHead',
+        inputs_indices=(0, 1, 2),
+        assigner=dict(type='L1DistanceAssigner'),
+        sampler=dict(type='ConcatSampler'),
         neck=dict(type='CoordinateHead'),
         roi_extractor=dict(
             type='BezierRoIExtractor',

@@ -97,7 +123,14 @@
             decoder=dict(
                 type='ABCNetRecDecoder',
                 dictionary=dictionary,
-                postprocessor=dict(type='AttentionPostprocessor'),
+                postprocessor=dict(
+                    type='AttentionPostprocessor',
+                    ignore_chars=['padding', 'unknown']),
+                module_loss=dict(
+                    type='CEModuleLoss',
+                    ignore_first_char=False,
+                    ignore_char=-1,
+                    reduction='mean'),
                 max_seq_len=25))),
     postprocessor=dict(
         type='ABCNetPostprocessor',

@@ -120,3 +153,32 @@
         type='PackTextDetInputs',
         meta_keys=('img_path', 'ori_shape', 'img_shape', 'scale_factor'))
 ]
+
+train_pipeline = [
+    dict(
+        type='LoadImageFromFile',
+        file_client_args=file_client_args,
+        color_type='color_ignore_orientation'),
+    dict(
+        type='LoadOCRAnnotations',
+        with_polygon=True,
+        with_bbox=True,
+        with_label=True,
+        with_text=True),
+    dict(type='RemoveIgnored'),
+    dict(type='RandomCrop', min_side_ratio=0.1),
+    dict(
+        type='RandomRotate',
+        max_angle=30,
+        pad_with_fixed_color=True,
+        use_canvas=True),
+    dict(
+        type='RandomChoiceResize',
+        scales=[(980, 2900), (1044, 2900), (1108, 2900), (1172, 2900),
+                (1236, 2900), (1300, 2900), (1364, 2900), (1428, 2900),
+                (1492, 2900)],
+        keep_ratio=True),
+    dict(
+        type='PackTextDetInputs',
+        meta_keys=('img_path', 'ori_shape', 'img_shape', 'scale_factor'))
+]
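
For reference, a hedged sketch of how a config built on this base would typically be run with MMEngine (the child config path and work_dir are hypothetical; the repo's usual tools/train.py entry point wraps the same calls):

    # Hedged sketch: launching ABCNet v2 training from a config that inherits
    # the base file edited above. The config path here is hypothetical.
    from mmengine.config import Config
    from mmengine.runner import Runner

    cfg = Config.fromfile(
        'projects/ABCNet/config/abcnet_v2/my_abcnet-v2_dataset.py')
    cfg.work_dir = 'work_dirs/abcnet_v2'        # assumed output directory

    runner = Runner.from_cfg(cfg)               # builds model, loops, hooks
    runner.train()                              # 500 epochs, validation every 10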
