Skip to content

Commit 8deb190

Browse files
fix : prediction post processing for offline workers
1 parent 547d4cb commit 8deb190

3 files changed

Lines changed: 27 additions & 8 deletions

File tree

backend/core/tasks.py

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -224,14 +224,17 @@ def _train_yolo(self, output_path):
224224
}
225225

226226
def _train_ramp(self, output_path):
227-
os.environ['TF_XLA_FLAGS'] = '--tf_xla_auto_jit=-1'
227+
os.environ["TF_XLA_FLAGS"] = "--tf_xla_auto_jit=-1"
228228
import tensorflow as tf
229229
from hot_fair_utilities import preprocess
230230
from hot_fair_utilities.training.ramp import train
231-
tf.config.optimizer.set_jit(False) # Disable XLA for RAMP training , bug in tensorflow 2.9.*
232-
231+
232+
tf.config.optimizer.set_jit(
233+
False
234+
) # Disable XLA for RAMP training , bug in tensorflow 2.9.*
235+
233236
# os.environ["TF_XLA_FLAGS"] = "--tf_xla_auto_jit=-1"
234-
237+
235238
setup_ramp()
236239
(
237240
inst,
@@ -523,14 +526,26 @@ def predict_area(prediction_request_id):
523526
predict_area.request.id,
524527
)
525528
inst.save()
526-
inst.config["geojson"] = inst.geom
527529
params = PredictionRequest(**inst.config)
530+
predictions = asyncio.run(
531+
predict(
532+
geojson=inst.geom.geojson,
533+
model_path=params.checkpoint,
534+
zoom_level=params.zoom_level,
535+
tms_url=params.source,
536+
confidence=params.confidence / 100,
537+
tolerance=params.tolerance,
538+
area_threshold=params.area_threshold,
539+
orthogonalize=params.use_josm_q,
540+
vectorization_algorithm=params.vectorization_algorithm,
541+
)
542+
)
528543

529-
predictions = asyncio.run(predict(**params))
530-
out = os.path.join(settings.PREDICTION_WORKSPACE, inst.id)
544+
out = os.path.join(settings.PREDICTION_WORKSPACE, str(inst.id))
545+
os.makedirs(out, exist_ok=True)
531546
write_json(
532547
os.path.join(out, "aois.geojson"),
533-
inst.geom,
548+
inst.geom.geojson,
534549
)
535550
write_json(
536551
os.path.join(out, "labels.geojson"),

backend/core/views.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
from rest_framework.response import Response
4545
from rest_framework.views import APIView
4646
from rest_framework.viewsets import ReadOnlyModelViewSet
47+
from rest_framework_gis.fields import GeometryField
4748
from rest_framework_gis.filters import InBBoxFilter, TMSTileFilter
4849

4950
from .models import (
@@ -1007,6 +1008,8 @@ def post(self, request, training_id, format=None):
10071008

10081009

10091010
class PredictionSerializer(serializers.ModelSerializer):
1011+
geom = GeometryField()
1012+
10101013
class Meta:
10111014
model = Prediction
10121015
fields = "__all__"

backend/pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,4 +55,5 @@ yolo-workers = [
5555
]
5656
prediction-workers = [
5757
"fairpredictor>=0.3.4",
58+
"tippecanoe>=2.72.0",
5859
]

0 commit comments

Comments
 (0)