Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,16 @@ Types of changes
* "Security" in case of vulnerabilities.
-->

## [1.5.1]

### Added

### Changed

### Fixed

- Stop enabling progress bars when they are disabled

## [1.5.0]

### Added
Expand Down
9 changes: 6 additions & 3 deletions span_marker/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import torch
import torch.nn.functional as F
from datasets import Dataset, disable_progress_bar, enable_progress_bar
from datasets import Dataset, disable_progress_bar, enable_progress_bar, is_progress_bar_enabled
from packaging.version import Version, parse
from torch import device, nn
from tqdm.autonotebook import trange
Expand Down Expand Up @@ -489,7 +489,9 @@ def predict(
"inference without document-level context may cause decreased performance."
)

if not show_progress_bar:
progress_bars_enabled = is_progress_bar_enabled()
if not show_progress_bar and progress_bars_enabled:
# disable progress bars if they are enabled
disable_progress_bar()
dataset = dataset.map(
Trainer.spread_sample,
Expand All @@ -500,7 +502,8 @@ def predict(
"marker_max_length": self.config.marker_max_length,
},
)
if not show_progress_bar:
if not show_progress_bar and progress_bars_enabled:
# re-enable progress bars now if they were originally enabled
enable_progress_bar()
for batch_start_idx in trange(0, len(dataset), batch_size, leave=True, disable=not show_progress_bar):
batch = dataset.select(range(batch_start_idx, min(len(dataset), batch_start_idx + batch_size)))
Expand Down