Skip to content

Commit 70a9eb1

Browse files
committed
Improve style.
1 parent abff95c commit 70a9eb1

File tree

14 files changed

+122
-84
lines changed

14 files changed

+122
-84
lines changed

openpmcvl/granular/models/network.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import torch
2-
import torch.nn as nn
2+
from torch import nn
33
from torch.utils.model_zoo import load_url as load_state_dict_from_url
44

55

openpmcvl/granular/models/process.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
from __future__ import division
22

3-
import torch
43
import cv2
54
import numpy as np
5+
import torch
66

77

88
def label2yolobox(labels, info_img, maxsize, lrflip):
@@ -22,7 +22,8 @@ class (float): class index.
2222
maxsize (int): target image size after pre-processing
2323
lrflip (bool): horizontal flip flag
2424
25-
Returns:
25+
Returns
26+
-------
2627
labels:label data whose size is :math:`(N, 5)`.
2728
Each label consists of [class, xc, yc, w, h] where
2829
class (float): class index.
@@ -77,7 +78,9 @@ def nms(bbox, thresh, score=None, limit=None):
7778
limit (int): The upper bound of the number of the output bounding
7879
boxes. If it is not specified, this method selects as many
7980
bounding boxes as possible.
80-
Returns:
81+
82+
Returns
83+
-------
8184
array:
8285
An array with indices of bounding boxes that are selected. \
8386
They are sorted by the scores of bounding boxes in descending \
@@ -87,7 +90,6 @@ def nms(bbox, thresh, score=None, limit=None):
8790
8891
from: https://github.com/chainer/chainercv
8992
"""
90-
9193
if len(bbox) == 0:
9294
return np.zeros((0,), dtype=np.int32)
9395

@@ -135,7 +137,8 @@ def postprocess(prediction, dtype, conf_thre=0.7, nms_thre=0.45):
135137
nms_thre (float):
136138
IoU threshold of non-max suppression ranging from 0 to 1.
137139
138-
Returns:
140+
Returns
141+
-------
139142
output (list of torch tensor):
140143
141144
"""
@@ -203,7 +206,8 @@ def preprocess(img, imgsize, jitter, random_placing=False):
203206
jitter (float): amplitude of jitter for resizing
204207
random_placing (bool): if True, place the image at random position
205208
206-
Returns:
209+
Returns
210+
-------
207211
img (numpy.ndarray): input image whose shape is :math:`(C, imgsize, imgsize)`.
208212
Values range from 0 to 1.
209213
info_img : tuple of h, w, nh, nw, dx, dy.

openpmcvl/granular/models/subfigure_detector.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
1+
import math
2+
13
import torch
4+
from einops import repeat
25
from pytorch_pretrained_bert.modeling import BertModel
36
from torch import nn
47
from torchvision import models
5-
import math
68

79
from openpmcvl.granular.models.transformer_module import *
8-
from einops import repeat
910

1011

1112
class FigCap_Former(nn.Module):
@@ -116,7 +117,8 @@ def forward(self, images, texts):
116117
images (compound figure): shape (bs, c, h, w)
117118
texts (caption tokens): shape (bs, max_length_in_this_batch)
118119
119-
Returns:
120+
Returns
121+
-------
120122
output_det_class: tensor (bs, query_num, 1), 0~1 indicate subfigure or no-subfigure
121123
output_box: tensor (bs, query_num, 4), prediction of [cx, cy, w, h]
122124
similarity: tensor (bs, query_num, caption_length), 0~1 indicate belong or not belong to the subfigure

openpmcvl/granular/models/subfigure_ocr.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,17 @@
11
import os
2-
import yaml
2+
3+
import cv2
4+
import numpy as np
35
import torch
6+
import torch.nn.functional as F
7+
import yaml
8+
from PIL import Image
49
from skimage import io
5-
import numpy as np
6-
import cv2
710
from torch.autograd import Variable
8-
from PIL import Image
9-
import torch.nn.functional as F
1011

11-
from openpmcvl.granular.models.yolov3 import YOLOv3
1212
from openpmcvl.granular.models.network import resnet152
13-
from openpmcvl.granular.models.process import preprocess, postprocess, yolobox2label
13+
from openpmcvl.granular.models.process import postprocess, preprocess, yolobox2label
14+
from openpmcvl.granular.models.yolov3 import YOLOv3
1415

1516

1617
class classifier:
@@ -44,7 +45,7 @@ def __init__(self):
4445
self.text_recognition_model.eval()
4546

4647
def load_model_from_checkpoint(self, model, model_name):
47-
"""load checkpoint weights into model"""
48+
"""Load checkpoint weights into model"""
4849
checkpoints_path = os.path.join(self.current_dir, "..", "checkpoints")
4950
checkpoint = os.path.join(checkpoints_path, model_name)
5051
model.load_state_dict(torch.load(checkpoint))
@@ -61,7 +62,6 @@ def detect_subfigure_boundaries(self, figure_path):
6162
subfigure_info (list of lists): Each inner list is
6263
x1, y1, x2, y2, confidence
6364
"""
64-
6565
## Preprocess the figure for the models
6666
img = io.imread(figure_path)
6767
if len(np.shape(img)) == 2:

openpmcvl/granular/models/transformer_module.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1+
import copy
2+
13
import torch
2-
from torch import nn
34
import torch.nn.functional as F
4-
import copy
5+
from torch import nn
56

67

78
class MultiHeadAttention(nn.Module):

openpmcvl/granular/models/yolo_layer.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1-
import torch
2-
import torch.nn as nn
3-
import numpy as np
41
import warnings
2+
3+
import numpy as np
4+
import torch
5+
from torch import nn
6+
57
from openpmcvl.granular.models.network import resnet152
68
from openpmcvl.granular.models.process import preprocess
79

@@ -18,7 +20,9 @@ def bboxes_iou(bboxes_a, bboxes_b, xyxy=True):
1820
bbox_b (array): An array similar to :obj:`bbox_a`,
1921
whose shape is :math:`(K, 4)`.
2022
The dtype should be :obj:`numpy.float32`.
21-
Returns:
23+
24+
Returns
25+
-------
2226
array:
2327
An array whose shape is :math:`(N, K)`. \
2428
An element at index :math:`(n, k)` contains IoUs between \
@@ -72,7 +76,6 @@ def __init__(self, config_model, layer_no, in_ch, ignore_thre=0.7):
7276
in_ch (int): number of input channels.
7377
ignore_thre (float): threshold of IoU above which objectness training is ignored.
7478
"""
75-
7679
super(YOLOLayer, self).__init__()
7780
strides = [32, 16, 8] # fixed
7881
self.anchors = config_model["ANCHORS"]
@@ -111,7 +114,9 @@ def forward(self, xin, compound_labels=None):
111114
class (float): class index.
112115
xc, yc (float) : center of bbox whose values range from 0 to 1.
113116
w, h (float) : size of bbox whose values range from 0 to 1.
114-
Returns:
117+
118+
Returns
119+
-------
115120
loss (torch.Tensor): total loss - the target of backprop.
116121
loss_xy (torch.Tensor): x, y loss - calculated by binary cross entropy (BCE) \
117122
with boxsize-dependent weights.
@@ -319,7 +324,6 @@ def __init__(self, config_model, layer_no, in_ch, ignore_thre=0.7):
319324
in_ch (int): number of input channels.
320325
ignore_thre (float): threshold of IoU above which objectness training is ignored.
321326
"""
322-
323327
super(YOLOimgLayer, self).__init__()
324328
strides = [32, 16, 8] # fixed
325329
self.anchors = config_model["ANCHORS"]
@@ -356,7 +360,9 @@ def forward(self, xin, all_labels=None):
356360
class (float): class index.
357361
xc, yc (float) : center of bbox whose values range from 0 to 1.
358362
w, h (float) : size of bbox whose values range from 0 to 1.
359-
Returns:
363+
364+
Returns
365+
-------
360366
loss (torch.Tensor): total loss - the target of backprop.
361367
loss_xy (torch.Tensor): x, y loss - calculated by binary cross entropy (BCE) \
362368
with boxsize-dependent weights.

openpmcvl/granular/models/yolov3.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1+
from collections import defaultdict
2+
13
import torch
2-
import torch.nn as nn
4+
from torch import nn
35

4-
from collections import defaultdict
5-
from openpmcvl.granular.models.yolo_layer import YOLOLayer, YOLOimgLayer
6+
from openpmcvl.granular.models.yolo_layer import YOLOimgLayer, YOLOLayer
67

78

89
def add_conv(in_ch, out_ch, ksize, stride):
@@ -13,7 +14,9 @@ def add_conv(in_ch, out_ch, ksize, stride):
1314
out_ch (int): number of output channels of the convolution layer.
1415
ksize (int): kernel size of the convolution layer.
1516
stride (int): stride of the convolution layer.
16-
Returns:
17+
18+
Returns
19+
-------
1720
stage (Sequential) : Sequential layers composing a convolution block.
1821
"""
1922
stage = nn.Sequential()
@@ -71,10 +74,11 @@ def create_yolov3_modules(config_model, ignore_thre):
7174
config_model (dict): model configuration.
7275
See YOLOLayer class for details.
7376
ignore_thre (float): used in YOLOLayer.
74-
Returns:
77+
78+
Returns
79+
-------
7580
mlist (ModuleList): YOLOv3 module list.
7681
"""
77-
7882
# DarkNet53
7983
mlist = nn.ModuleList()
8084
mlist.append(add_conv(in_ch=3, out_ch=32, ksize=3, stride=1))
@@ -153,7 +157,8 @@ def forward(self, x, targets=None):
153157
where N, C are batchsize and num. of channels.
154158
targets (torch.Tensor) : label array whose shape is :math:`(N, 50, 5)`
155159
156-
Returns:
160+
Returns
161+
-------
157162
training:
158163
output (torch.Tensor): loss tensor for backpropagation.
159164
test:
@@ -200,10 +205,11 @@ def create_yolov3img_modules(config_model, ignore_thre):
200205
config_model (dict): model configuration.
201206
See YOLOLayer class for details.
202207
ignore_thre (float): used in YOLOLayer.
203-
Returns:
208+
209+
Returns
210+
-------
204211
mlist (ModuleList): YOLOv3 module list.
205212
"""
206-
207213
# DarkNet53
208214
mlist = nn.ModuleList()
209215
mlist.append(add_conv(in_ch=4, out_ch=32, ksize=3, stride=1))
@@ -282,7 +288,8 @@ def forward(self, x, targets=None):
282288
where N, C are batchsize and num. of channels.
283289
targets (torch.Tensor) : label array whose shape is :math:`(N, 50, 5)`
284290
285-
Returns:
291+
Returns
292+
-------
286293
training:
287294
output (torch.Tensor): loss tensor for backpropagation.
288295
test:

openpmcvl/granular/pipeline/align.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1-
import os
21
import argparse
2+
import os
33
from typing import Dict
44

55
from tqdm import tqdm
6+
67
from openpmcvl.granular.models.subfigure_ocr import classifier
78
from openpmcvl.granular.pipeline.utils import load_dataset, save_jsonl
89

@@ -15,7 +16,8 @@ def process_subfigure(model: classifier, subfig_data: Dict) -> Dict:
1516
model (classifier): Initialized OCR model
1617
subfig_data (Dict): Dictionary containing subfigure data
1718
18-
Returns:
19+
Returns
20+
-------
1921
Dict: Updated subfigure data with OCR results
2022
"""
2123
if "subfig_path" not in subfig_data:

openpmcvl/granular/pipeline/classify.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,16 @@
11
import argparse
2-
from PIL import Image
32
from typing import Any, Dict, List
43

54
import torch
6-
import torch.nn as nn
5+
from PIL import Image
6+
from torch import nn
77
from torch.utils.data import DataLoader
88
from torchvision import models, transforms
99
from tqdm import tqdm
1010

11-
from openpmcvl.granular.pipeline.utils import load_dataset, save_jsonl
1211
from openpmcvl.granular.dataset.dataset import SubfigureDataset
12+
from openpmcvl.granular.pipeline.utils import load_dataset, save_jsonl
13+
1314

1415
MEDICAL_CLASS = 15
1516
CLASSIFICATION_THRESHOLD = 4
@@ -23,7 +24,8 @@ def load_classification_model(model_path: str, device: torch.device) -> nn.Modul
2324
model_path (str): Path to the classification model checkpoint
2425
device (torch.device): Device to use for processing
2526
26-
Returns:
27+
Returns
28+
-------
2729
nn.Module: Loaded classification model
2830
"""
2931
fig_model = models.resnext101_32x8d()
@@ -55,7 +57,9 @@ def classify_dataset(
5557
device (torch.device): Device to use for processing.
5658
output_file (str): Path to save the updated JSONL file with classification results.
5759
num_workers (int): Number of workers for processing.
58-
Returns:
60+
61+
Returns
62+
-------
5963
None
6064
"""
6165
mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
@@ -118,7 +122,8 @@ def main(args: argparse.Namespace) -> None:
118122
- batch_size (int): Batch size for processing
119123
- output_file (str): Path to save the JSONL file with classification results
120124
121-
Returns:
125+
Returns
126+
-------
122127
None
123128
"""
124129
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

0 commit comments

Comments
 (0)