Skip to content

Commit 6bbca60

Browse files
committed
fix circular import problems
1 parent cf95e94 commit 6bbca60

File tree

4 files changed

+16
-12
lines changed

4 files changed

+16
-12
lines changed

rastervision_pytorch_backend/rastervision/pytorch_backend/pytorch_chip_classification.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,15 @@
77
from rastervision.pytorch_backend.pytorch_learner_backend import (
88
PyTorchLearnerSampleWriter, PyTorchLearnerBackend)
99
from rastervision.pytorch_backend.utils import chip_collate_fn_cc
10-
from rastervision.pytorch_learner import (
11-
ClassificationGeoDataConfig, ClassificationSlidingWindowGeoDataset)
10+
from rastervision.pytorch_learner.dataset import (
11+
ClassificationSlidingWindowGeoDataset)
1212
from rastervision.core.data import ChipClassificationLabels
1313

1414
if TYPE_CHECKING:
1515
import numpy as np
1616
from rastervision.core.data import DatasetConfig, Scene
1717
from rastervision.core.rv_pipeline import ChipOptions, PredictOptions
18+
from rastervision.pytorch_learner import ClassificationGeoDataConfig
1819

1920

2021
class PyTorchChipClassificationSampleWriter(PyTorchLearnerSampleWriter):
@@ -89,7 +90,8 @@ def predict_scene(self, scene: 'Scene', predict_options: 'PredictOptions'
8990

9091
def _make_chip_data_config(
9192
self, dataset: 'DatasetConfig',
92-
chip_options: 'ChipOptions') -> ClassificationGeoDataConfig:
93+
chip_options: 'ChipOptions') -> 'ClassificationGeoDataConfig':
94+
from rastervision.pytorch_learner import (ClassificationGeoDataConfig)
9395
data_config = ClassificationGeoDataConfig(
9496
scene_dataset=dataset, sampling=chip_options.sampling)
9597
return data_config

rastervision_pytorch_backend/rastervision/pytorch_backend/pytorch_learner_backend.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,12 @@
1212
from rastervision.core.data.utils.misc import save_img
1313
from rastervision.core.data_sample import DataSample
1414
from rastervision.pytorch_learner.learner import Learner
15-
from rastervision.pytorch_learner.learner_config import DataConfig
1615

1716
if TYPE_CHECKING:
1817
from torch.utils.data import Dataset
1918
from rastervision.core.data import ClassConfig, DatasetConfig, Scene
2019
from rastervision.core.rv_pipeline import RVPipelineConfig, ChipOptions
21-
from rastervision.pytorch_learner.learner_config import LearnerConfig
20+
from rastervision.pytorch_learner import DataConfig, LearnerConfig
2221

2322
SPLITS = ['train', 'valid', 'test']
2423

rastervision_pytorch_backend/rastervision/pytorch_backend/pytorch_object_detection.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,14 @@
1212
from rastervision.pytorch_backend.utils import chip_collate_fn_od
1313
from rastervision.pytorch_learner.dataset import (
1414
ObjectDetectionSlidingWindowGeoDataset)
15-
from rastervision.pytorch_learner.object_detection_learner_config import (
16-
ObjectDetectionGeoDataConfig)
1715

1816
if TYPE_CHECKING:
1917
from rastervision.core.data import DatasetConfig, Scene
2018
from rastervision.core.rv_pipeline import (ChipOptions,
2119
ObjectDetectionPredictOptions)
2220
from rastervision.pytorch_learner.object_detection_utils import BoxList
21+
from rastervision.pytorch_learner.object_detection_learner_config import (
22+
ObjectDetectionGeoDataConfig)
2323

2424

2525
class PyTorchObjectDetectionSampleWriter(PyTorchLearnerSampleWriter):
@@ -154,7 +154,8 @@ def predict_scene(self, scene: 'Scene',
154154

155155
def _make_chip_data_config(
156156
self, dataset: 'DatasetConfig',
157-
chip_options: 'ChipOptions') -> ObjectDetectionGeoDataConfig:
157+
chip_options: 'ChipOptions') -> 'ObjectDetectionGeoDataConfig':
158+
from rastervision.pytorch_learner import (ObjectDetectionGeoDataConfig)
158159
data_config = ObjectDetectionGeoDataConfig(
159160
scene_dataset=dataset, sampling=chip_options.sampling)
160161
return data_config

rastervision_pytorch_backend/rastervision/pytorch_backend/pytorch_semantic_segmentation.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,13 @@
1212
from rastervision.pytorch_backend.utils import chip_collate_fn_ss
1313
from rastervision.pytorch_learner.dataset import (
1414
SemanticSegmentationSlidingWindowGeoDataset)
15-
from rastervision.pytorch_learner import SemanticSegmentationGeoDataConfig
1615

1716
if TYPE_CHECKING:
1817
from rastervision.core.data import (DatasetConfig, Scene,
1918
SemanticSegmentationLabelStore)
2019
from rastervision.core.rv_pipeline import (
2120
ChipOptions, SemanticSegmentationPredictOptions)
21+
from rastervision.pytorch_learner import SemanticSegmentationGeoDataConfig
2222

2323

2424
class PyTorchSemanticSegmentationSampleWriter(PyTorchLearnerSampleWriter):
@@ -118,9 +118,11 @@ def predict_scene(self, scene: 'Scene',
118118

119119
return labels
120120

121-
def _make_chip_data_config(
122-
self, dataset: 'DatasetConfig',
123-
chip_options: 'ChipOptions') -> SemanticSegmentationGeoDataConfig:
121+
def _make_chip_data_config(self, dataset: 'DatasetConfig',
122+
chip_options: 'ChipOptions'
123+
) -> 'SemanticSegmentationGeoDataConfig':
124+
from rastervision.pytorch_learner import (
125+
SemanticSegmentationGeoDataConfig)
124126
data_config = SemanticSegmentationGeoDataConfig(
125127
scene_dataset=dataset, sampling=chip_options.sampling)
126128
return data_config

0 commit comments

Comments
 (0)