Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 27 additions & 10 deletions torchxrayvision/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@
"weights_url": 'https://github.com/mlmed/torchxrayvision/releases/download/v1/nih-pc-chex-mimic_ch-google-openi-kaggle-densenet121-d121-tw-lr001-rot45-tr15-sc15-seed0-best.pt',
"labels": ['Atelectasis', 'Consolidation', 'Infiltration', 'Pneumothorax', 'Edema', 'Emphysema', 'Fibrosis', 'Effusion', 'Pneumonia', 'Pleural_Thickening', 'Cardiomegaly', 'Nodule', 'Mass', 'Hernia', 'Lung Lesion', 'Fracture', 'Lung Opacity', 'Enlarged Cardiomediastinum'],
"op_threshs": [0.07422872, 0.038290843, 0.09814756, 0.0098118475, 0.023601074, 0.0022490358, 0.010060724, 0.103246614, 0.056810737, 0.026791653, 0.050318155, 0.023985857, 0.01939503, 0.042889766, 0.053369623, 0.035975814, 0.20204692, 0.05015312],
"ppv80_thres": [0.72715247, 0.8885005, 0.92493945, 0.6527224, 0.68707734, 0.46127197, 0.7272054, 0.6127343, 0.9878492, 0.61979693, 0.66309816, 0.7853459, 0.930661, 0.93645346, 0.6788558, 0.6547198, 0.61614525, 0.8489876]
"ppv80_thres": [0.72715247, 0.8885005, 0.92493945, 0.6527224, 0.68707734, 0.46127197, 0.7272054, 0.6127343, 0.9878492, 0.61979693, 0.66309816, 0.7853459, 0.930661, 0.93645346, 0.6788558, 0.6547198, 0.61614525, 0.8489876],
"input_resolution": 224,
}
model_urls['densenet121-res224-all'] = model_urls['all']

Expand All @@ -32,43 +33,49 @@
"weights_url": 'https://github.com/mlmed/torchxrayvision/releases/download/v1/nih-densenet121-d121-tw-lr001-rot45-tr15-sc15-seed0-best.pt',
"labels": ['Atelectasis', 'Consolidation', 'Infiltration', 'Pneumothorax', 'Edema', 'Emphysema', 'Fibrosis', 'Effusion', 'Pneumonia', 'Pleural_Thickening', 'Cardiomegaly', 'Nodule', 'Mass', 'Hernia', '', '', '', ''],
"op_threshs": [0.039117552, 0.0034529066, 0.11396341, 0.0057298196, 0.00045666535, 0.0018880932, 0.012037827, 0.038744126, 0.0037213727, 0.014730946, 0.016149804, 0.054241467, 0.037198864, 0.0004403434, np.nan, np.nan, np.nan, np.nan],
"input_resolution": 224,
}
model_urls['densenet121-res224-nih'] = model_urls['nih']


model_urls['pc'] = {
"weights_url": 'https://github.com/mlmed/torchxrayvision/releases/download/v1/pc-densenet121-d121-tw-lr001-rot45-tr15-sc15-seed0-best.pt',
"labels": ['Atelectasis', 'Consolidation', 'Infiltration', 'Pneumothorax', 'Edema', 'Emphysema', 'Fibrosis', 'Effusion', 'Pneumonia', 'Pleural_Thickening', 'Cardiomegaly', 'Nodule', 'Mass', 'Hernia', '', 'Fracture', '', ''],
"op_threshs": [0.031012505, 0.013347598, 0.081435576, 0.001262615, 0.002587246, 0.0035944257, 0.0023071, 0.055412333, 0.044385884, 0.042766232, 0.043258056, 0.037629247, 0.005658899, 0.0091741895, np.nan, 0.026507627, np.nan, np.nan]
"op_threshs": [0.031012505, 0.013347598, 0.081435576, 0.001262615, 0.002587246, 0.0035944257, 0.0023071, 0.055412333, 0.044385884, 0.042766232, 0.043258056, 0.037629247, 0.005658899, 0.0091741895, np.nan, 0.026507627, np.nan, np.nan],
"input_resolution": 224,
}
model_urls['densenet121-res224-pc'] = model_urls['pc']

model_urls['chex'] = {
"weights_url": 'https://github.com/mlmed/torchxrayvision/releases/download/v1/chex-densenet121-d121-tw-lr001-rot45-tr15-sc15-seed0-best.pt',
"labels": ['Atelectasis', 'Consolidation', '', 'Pneumothorax', 'Edema', '', '', 'Effusion', 'Pneumonia', '', 'Cardiomegaly', '', '', '', 'Lung Lesion', 'Fracture', 'Lung Opacity', 'Enlarged Cardiomediastinum'],
"op_threshs": [0.1988969, 0.05710573, np.nan, 0.0531293, 0.1435217, np.nan, np.nan, 0.27212676, 0.07749717, np.nan, 0.19712369, np.nan, np.nan, np.nan, 0.09932402, 0.09273402, 0.3270967, 0.10888247],
"input_resolution": 224,
}
model_urls['densenet121-res224-chex'] = model_urls['chex']


model_urls['rsna'] = {
"weights_url": 'https://github.com/mlmed/torchxrayvision/releases/download/v1/kaggle-densenet121-d121-tw-lr001-rot45-tr15-sc15-seed0-best.pt',
"labels": ['', '', '', '', '', '', '', '', 'Pneumonia', '', '', '', '', '', '', '', 'Lung Opacity', ''],
"op_threshs": [np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, 0.13486601, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, 0.13511065, np.nan]
"op_threshs": [np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, 0.13486601, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, 0.13511065, np.nan],
"input_resolution": 224,
}
model_urls['densenet121-res224-rsna'] = model_urls['rsna']

model_urls['mimic_nb'] = {
"weights_url": 'https://github.com/mlmed/torchxrayvision/releases/download/v1/mimic_nb-densenet121-d121-tw-lr001-rot45-tr15-sc15-seed0-best.pt',
"labels": ['Atelectasis', 'Consolidation', '', 'Pneumothorax', 'Edema', '', '', 'Effusion', 'Pneumonia', '', 'Cardiomegaly', '', '', '', 'Lung Lesion', 'Fracture', 'Lung Opacity', 'Enlarged Cardiomediastinum'],
"op_threshs": [0.08558747, 0.011884617, np.nan, 0.0040595434, 0.010733786, np.nan, np.nan, 0.118761964, 0.022924708, np.nan, 0.06358637, np.nan, np.nan, np.nan, 0.022143636, 0.017476924, 0.1258702, 0.014020768],
"input_resolution": 224,
}
model_urls['densenet121-res224-mimic_nb'] = model_urls['mimic_nb']

model_urls['mimic_ch'] = {
"weights_url": 'https://github.com/mlmed/torchxrayvision/releases/download/v1/mimic_ch-densenet121-d121-tw-lr001-rot45-tr15-sc15-seed0-best.pt',
"labels": ['Atelectasis', 'Consolidation', '', 'Pneumothorax', 'Edema', '', '', 'Effusion', 'Pneumonia', '', 'Cardiomegaly', '', '', '', 'Lung Lesion', 'Fracture', 'Lung Opacity', 'Enlarged Cardiomediastinum'],
"op_threshs": [0.09121389, 0.010573786, np.nan, 0.005023008, 0.003698257, np.nan, np.nan, 0.08001232, 0.037242252, np.nan, 0.05006329, np.nan, np.nan, np.nan, 0.019866971, 0.03823637, 0.11303808, 0.0069147074],
"input_resolution": 224,
}
model_urls['densenet121-res224-mimic_ch'] = model_urls['mimic_ch']

Expand All @@ -77,7 +84,8 @@
"weights_url": 'https://github.com/mlmed/torchxrayvision/releases/download/v1/pc-nih-rsna-siim-vin-resnet50-test512-e400-state.pt',
"labels": ['Atelectasis', 'Consolidation', 'Infiltration', 'Pneumothorax', 'Edema', 'Emphysema', 'Fibrosis', 'Effusion', 'Pneumonia', 'Pleural_Thickening', 'Cardiomegaly', 'Nodule', 'Mass', 'Hernia', 'Lung Lesion', 'Fracture', 'Lung Opacity', 'Enlarged Cardiomediastinum'],
"op_threshs": [0.51570356, 0.50444704, 0.53787947, 0.50723547, 0.5025118, 0.5035252, 0.5038076, 0.51862943, 0.5078151, 0.50724894, 0.5056339, 0.510706, 0.5053923, 0.5020846, np.nan, 0.5080557, 0.5138526, np.nan],
"ppv80_thres": [0.690908, 0.720028, 0.7303882, 0.7235838, 0.6787441, 0.7304924, 0.73105824, 0.6839408, 0.7241559, 0.7219969, 0.6346738, 0.72764945, 0.7285066, 0.5735704, np.nan, 0.69684714, 0.7135549, np.nan]
"ppv80_thres": [0.690908, 0.720028, 0.7303882, 0.7235838, 0.6787441, 0.7304924, 0.73105824, 0.6839408, 0.7241559, 0.7219969, 0.6346738, 0.72764945, 0.7285066, 0.5735704, np.nan, 0.69684714, 0.7135549, np.nan],
"input_resolution": 512,
}

# Just created for documentation
Expand Down Expand Up @@ -313,24 +321,29 @@ def __init__(self,
if "op_threshs" in model_urls[weights]:
self.op_threshs = torch.tensor(model_urls[weights]["op_threshs"])

if "input_resolution" in model_urls[weights]:
self.input_resolution = model_urls[weights]["input_resolution"]

def __repr__(self):
if self.weights is not None:
return "XRV-DenseNet121-{}".format(self.weights)
else:
return "XRV-DenseNet"

def features2(self, x):
x = utils.fix_resolution(x, 224, self)
utils.warn_normalization(x)
if hasattr(self, 'input_resolution'):
x = utils.fix_resolution(x, self.input_resolution, self)
utils.warn_normalization(x)

features = self.features(x)
out = F.relu(features, inplace=True)
out = F.adaptive_avg_pool2d(out, (1, 1)).view(features.size(0), -1)
return out

def forward(self, x):
x = utils.fix_resolution(x, 224, self)
utils.warn_normalization(x)
if hasattr(self, 'input_resolution'):
x = utils.fix_resolution(x, self.input_resolution, self)
utils.warn_normalization(x)

features = self.features2(x)
out = self.classifier(features)
Expand Down Expand Up @@ -419,6 +432,9 @@ def __init__(self, weights: str = None, apply_sigmoid: bool = False, cache_dir:
if "op_threshs" in model_urls[weights]:
self.register_buffer('op_threshs', torch.tensor(model_urls[weights]["op_threshs"]))

if "input_resolution" in model_urls[weights]:
self.input_resolution = model_urls[weights]["input_resolution"]

self.eval()

def __repr__(self):
Expand Down Expand Up @@ -446,8 +462,9 @@ def features(self, x):
return x

def forward(self, x):
x = utils.fix_resolution(x, 512, self)
utils.warn_normalization(x)
if hasattr(self, 'input_resolution'):
x = utils.fix_resolution(x, self.input_resolution, self)
utils.warn_normalization(x)

out = self.model(x)

Expand Down