Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
44 commits
Select commit Hold shift + click to select a range
9f4bcf4
Update testing json
Oct 23, 2025
834a9fa
Add all Samar's code at once
kortukov Oct 23, 2025
add4aa9
Update lcrp submodule
kortukov Oct 23, 2025
e07c3cd
Update code to have correct alert_ref
kortukov Oct 23, 2025
373cddb
Add half type to PIDNet
kortukov Oct 23, 2025
111e77c
Make PIDNet work in half precision
kortukov Oct 23, 2025
219eacf
Do both explanation tasks
kortukov Oct 23, 2025
44cb139
update ploting 3
samarheydari Oct 24, 2025
21cdbbe
Heatmap Optimisation (Layer)
samarheydari Oct 27, 2025
7e593d7
ploting and prototype optimisation
samarheydari Oct 27, 2025
d2c9220
Refactor FloodDataset for improved mask pairing and normalization
samarheydari Nov 13, 2025
d8a6a92
update based on the BRK code
samarheydari Nov 18, 2025
30e6fce
based on the last changes from AUTH
samarheydari Nov 24, 2025
9867f78
Delete .DS_Store
samarheydari Nov 24, 2025
a89e4f6
Ignore h5 files
samarheydari Nov 24, 2025
eec7d6e
Adding the Reference List
samarheydari Nov 24, 2025
c7ab92c
new PCX
samarheydari Nov 24, 2025
0c19ae2
based on BRK trial updated
samarheydari Nov 24, 2025
73a701d
Fix formatting issue in flood_dataset.py by ensuring a newline at the…
samarheydari Jan 13, 2026
187b489
concepts and prototypes
samarheydari Jan 15, 2026
4da59bc
Update prototype and cluster and outlier
samarheydari Jan 15, 2026
5d6a127
add outlier sample and optimize plots
samarheydari Jan 21, 2026
79c1218
update the crp code extraction
Jan 22, 2026
3c1dab8
Change YOLO preprocessing to letterbox and use original images for cl…
Jan 22, 2026
a908119
Change YOLO preprocessing to letterbox and use original images for cl…
Jan 22, 2026
f2f8845
Increase CRP visualization sample size for more reference images
Jan 22, 2026
d891690
add other functions in pcx helper
Jan 22, 2026
f7d317b
add more image extensions to read all images
Jan 22, 2026
a16a7f3
yolo crp generation notebook
Jan 22, 2026
d85e1aa
yolo pcx generation and visualisation notebook
Jan 22, 2026
826b60e
improve the yolo pcx plotting
Jan 22, 2026
f6394ff
add the possibility to use vis_opaque_img_border_v2
Jan 22, 2026
687d0a4
add letterbox tils file for preprocessing
Jan 22, 2026
88c0e55
some fixes for the explanator.py
Jan 22, 2026
608c694
path changes
Jan 22, 2026
91d2854
import fixes
Jan 22, 2026
4837af1
Merge pull request #12 from jawher0001/samar_yolo_updates
Jan 23, 2026
6deeb5b
add yolo_pcx_test.py
Jan 27, 2026
d725b62
fix import issues
Jan 27, 2026
5c3a788
for the feb 12 report
samarheydari Feb 13, 2026
09eac9f
Reduce notebook size: strip outputs and track via LFS
samarheydari Feb 13, 2026
bafbc2b
Recover work before checkout
samarheydari Mar 18, 2026
11d09fd
Update LCRP submodule after recovery
samarheydari Mar 18, 2026
b239894
Merge origin/samar into recovered samar
samarheydari Mar 18, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added .DS_Store
Binary file not shown.
1 change: 1 addition & 0 deletions .gitattributes
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
*.ipynb filter=lfs diff=lfs merge=lfs -text
15 changes: 14 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,7 +1,20 @@
# Ignore large files and folders
models/
data/
output/
*.pt
*.pth
*.ckpt
# Ignore data, checkpoints, and models
data/
models/
output/
*.pt
models/checkpoints/*
datasets/data/*
__pycache__/
output/

.venv/
.conda/
.conda/*.h5
*.h5
37 changes: 30 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# L-CRP-TEMA
# PCX-TEMA

This repository contains the code for applying the L-CRP method for TEMA project.
This repository contains the code for applying the PCX method for TEMA project.

## Setting up

Expand All @@ -10,17 +10,16 @@ Data and models are available on Google Drive https://drive.google.com/drive/fol

Path for data is `datasets/data`.

#### 1. U-Net model:
- Checkpoint: Checkpoints/unet_flood.pt
- Dataset: Data/General_Flood_v3.zip
- Task: specifically for this model - flood detection
#### 1.PIDNet model:
- Checkpoint: Checkpoints/flood_model.pt
- Dataset: Data/flood_segmentation.zip
- Task: specifically for this model - flood segmentation

#### 2. YOLOv6s6 model:
- Checkpoint: Checkpoints/best_v6s6_ckpt.pt
- Dataset: Data/PersonCarDetectionData
- Task: person and car detection

#### 3. PIDNet model: not yet available


### Build the Docker image
Expand Down Expand Up @@ -124,4 +123,28 @@ python test_post_data.py ImageMetadata --cloud
The test script will send sample image metadata to the application and you should see the processing results in the logs of both the Flask application and the worker process.


### References


LRP (Layer-wise Relevance Propagation):

- original paper: (https://doi.org/10.1371/journal.pone.0130140)
- overview paper: (https://doi.org/10.1007/978-3-030-28954-6_10)
- zennit toolbox: (https://github.com/chr5tphr/zennit)

CRP (Concept Relevance Propagation):

- paper: (https://doi.org/10.1038/s42256-023-00711-8)
- zennit-crp toolbox: (https://github.com/rachtibat/zennit-crp)


L-CRP (Concept Relevance Propagation for Localization Models):

- paper: (https://arxiv.org/pdf/2211.11426)
- L-CRP code: (https://github.com/maxdreyer/L-CRP/tree/main)


PCX (Prototypical Concept-based Explanations):
- paper: (https://arxiv.org/pdf/2311.16681)
- PCX code: (https://github.com/maxdreyer/pcx/tree/main)

Binary file added References/Reference List.pdf
Binary file not shown.
92 changes: 80 additions & 12 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
# Import your existing modules
from src.explanator import Explanator
from src.minio_client import MinIOClient, FHHI_MINIO_BUCKET, NAPLES_MINIO_BUCKET
from common_app_funcs import update_entity, get_bm_id, set_bm_id, update_job_status, get_job_status, get_redis_conn, get_job_queue
from tasks import process_image_task, explanator
from common_app_funcs import update_entity, get_bm_id, set_bm_id,get_uav_id,set_uav_id,get_flight_number,set_flight_number, set_alert_ref_id , get_alert_ref_id , update_job_status, get_job_status, get_redis_conn, get_job_queue
from tasks import process_image_task


# Set up Redis connection and queue
Expand Down Expand Up @@ -232,9 +232,15 @@ def post_data():

# Quick validation check
if entity_type == "Alert":
uav_id = entity["uav_id"]["value"]
flight_number = entity["flight_number"]["value"]
bm_id = entity["bm_id"]["value"]
set_bm_id(bm_id)
msg = f"Received Alert with bm_id and saved to redis: {bm_id}"
alert_ref = entity["alert_ref"]["object"]
set_uav_id(redis_conn, uav_id)
set_flight_number(redis_conn, flight_number)
set_bm_id(redis_conn, bm_id)
set_alert_ref_id(redis_conn, alert_ref)
msg = f"Received Alert with bm_id and alert_ref and uav_id and flight_number saved to redis: {bm_id}, {alert_ref}, uav_id: {uav_id}, flight_number: {flight_number}"
return jsonify({'message': msg}), 200

explanator = get_explanator()
Expand All @@ -245,18 +251,61 @@ def post_data():
return jsonify({'error': err_msg}), 400

current_bm_id = get_bm_id(redis_conn)
current_alert_ref_id = get_alert_ref_id(redis_conn)
current_uav_id = get_uav_id(redis_conn)
current_flight_number = get_flight_number(redis_conn)
app.logger.debug(f"Current alert_ref: {current_alert_ref_id}")
app.logger.debug(f"Current bm_id: {current_bm_id}")
app.logger.debug(f"Current uav_id: {current_uav_id}")
app.logger.debug(f"Current flight_number: {current_flight_number}")

# Extract image information
posted_bm_id = entity["bm_id"]["value"]
if posted_bm_id != current_bm_id:
app.logger.warning(f"Received bm_id: {posted_bm_id} does not match current bm_id: {current_bm_id}")
posted_uav_id = entity["uav_id"]["value"]
if current_uav_id is None:
app.logger.info(f"No cached uav_id; storing value {posted_uav_id}.")
set_uav_id(redis_conn, posted_uav_id)
current_uav_id = posted_uav_id
elif posted_uav_id != current_uav_id:
app.logger.warning(f"Received uav_id: {posted_uav_id} does not match current uav_id: {current_uav_id}; updating stored uav_id.")
set_uav_id(redis_conn, posted_uav_id)
current_uav_id = posted_uav_id

posted_flight_number = entity["flight_number"]["value"]
if current_flight_number is None:
app.logger.info(f"No cached flight_number; storing value {posted_flight_number}.")
set_flight_number(redis_conn, posted_flight_number)
current_flight_number = posted_flight_number
elif posted_flight_number != current_flight_number:
app.logger.warning(f"Received flight_number: {posted_flight_number} does not match current flight_number: {current_flight_number}; updating stored flight_number.")
set_flight_number(redis_conn, posted_flight_number)
current_flight_number = posted_flight_number



posted_bm_id = entity["bm_id"]["value"]
if current_bm_id is None:
app.logger.info(f"No cached bm_id; storing value {posted_bm_id}.")
set_bm_id(redis_conn, posted_bm_id)
current_bm_id = posted_bm_id
elif posted_bm_id != current_bm_id:
app.logger.warning(f"Received bm_id: {posted_bm_id} does not match current bm_id: {current_bm_id}; updating stored bm_id.")
set_bm_id(redis_conn, posted_bm_id)
current_bm_id = posted_bm_id

posted_alert_ref = entity["alert_ref"]["object"]
if current_alert_ref_id is None:
app.logger.info(f"No cached alert_ref; storing value {posted_alert_ref}.")
set_alert_ref_id(redis_conn, posted_alert_ref)
elif posted_alert_ref != current_alert_ref_id:
app.logger.warning(f"Received alert_ref: {posted_alert_ref} does not match current alert_ref: {current_alert_ref_id}; updating stored alert_ref.")
set_alert_ref_id(redis_conn, posted_alert_ref)
current_alert_ref_id = posted_alert_ref
src_image_filename = entity["filename"]["value"]
src_image_bucket = entity["bucket"]["value"]

# Submit tasks for both PersonVehicleDetection and FloodSegmentation
# entities_to_explain = ['PersonVehicleDetection']
# entities_to_explain = ['FloodSegmentation']
entities_to_explain = ['FloodSegmentation', 'PersonVehicleDetection']

task_ids = []
Expand All @@ -272,6 +321,10 @@ def post_data():
'entity_type': entity_type,
'src_image_bucket': src_image_bucket,
'minio_filename': src_image_filename,
'uav_id': posted_uav_id,
'flight_number': posted_flight_number,
'bm_id': posted_bm_id,
'alert_ref': posted_alert_ref
}
update_job_status(redis_conn, task_id, job_status)

Expand All @@ -282,6 +335,10 @@ def post_data():
src_image_bucket,
src_image_filename,
task_id,
uav_id=posted_uav_id,
flight_number=posted_flight_number,
bm_id=posted_bm_id,
alert_ref=posted_alert_ref,
job_timeout='12h' # Set an appropriate timeout
)
task_ids.append(task_id)
Expand Down Expand Up @@ -322,15 +379,26 @@ def post_data_old():
app.logger.debug(f"Received outer entity type: {outer_entity_type} instead of Notification")

if entity_type == "Alert":
uav_id = entity["uav_id"]["value"]
flight_number = entity["flight_number"]["value"]
bm_id = entity["bm_id"]["value"]
set_bm_id(bm_id)
msg = f"Received Alert with bm_id: {bm_id}"
alert_ref = entity["alert_ref"]["value"]
set_alert_ref_id(redis_conn, alert_ref)
set_bm_id(redis_conn, bm_id)
set_uav_id(redis_conn, uav_id)
set_flight_number(redis_conn, flight_number)
msg = f"Received Alert with bm_id: {bm_id} and alert_ref: {alert_ref} saved to redis"
return jsonify({'message': msg}), 200

bm_id = get_bm_id()
bm_id = get_bm_id(redis_conn)
uav_id = get_uav_id(redis_conn)
flight_number = get_flight_number(redis_conn)
alert_ref = get_alert_ref_id(redis_conn)
app.logger.debug(f"Current alert_ref: {alert_ref}")

app.logger.debug(f"Current bm_id: {bm_id}")

app.logger.debug(f"Current uav_id: {uav_id}")
app.logger.debug(f"Current flight_number: {flight_number}")
explanator = get_explanator()

if entity_type not in explanator.VALID_ENTITY_TYPES:
Expand Down Expand Up @@ -516,4 +584,4 @@ def subscribe_to_context_broker():
# Run the Flask application
if __name__ == '__main__':
debug = os.environ.get('DEBUG', '').lower() in ('true', '1')
app.run(host='0.0.0.0', port=int(os.environ.get('PORT', 8080)), debug=debug)
app.run(host='0.0.0.0', port=int(os.environ.get('PORT', 8080)), debug=debug)
24 changes: 24 additions & 0 deletions common_app_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,30 @@ def get_bm_id(redis_conn):
bm_id = redis_conn.get('current_bm_id')
return bm_id.decode('utf-8') if bm_id else None

def set_uav_id(redis_conn, uav_id):
redis_conn.set('current_uav_id', uav_id)

def get_uav_id(redis_conn):
uav_id = redis_conn.get('current_uav_id')
return uav_id.decode('utf-8') if uav_id else None


def set_flight_number(redis_conn, flight_number):
redis_conn.set('current_flight_number', flight_number)

def get_flight_number(redis_conn):
flight_number = redis_conn.get('current_flight_number')
return flight_number.decode('utf-8') if flight_number else None



def set_alert_ref_id(redis_conn, alert_ref):
redis_conn.set('current_alert_ref_id', alert_ref)

def get_alert_ref_id(redis_conn):
alert_ref = redis_conn.get('current_alert_ref_id')
return alert_ref.decode('utf-8') if alert_ref else None

def update_job_status(redis_conn, task_id, status_data):
redis_conn.set(f"job_status:{task_id}", json.dumps(status_data))

Expand Down
11 changes: 11 additions & 0 deletions configs/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# ------------------------------------------------------------------------------
# Copyright (c) Microsoft
# Licensed under the MIT License.
# Written by Ke Sun (sunk@mail.ustc.edu.cn)
# ------------------------------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from .default import _C as config
from .default import update_config
105 changes: 105 additions & 0 deletions configs/default.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
# ------------------------------------------------------------------------------
# Modified based on https://github.com/HRNet/HRNet-Semantic-Segmentation
# ------------------------------------------------------------------------------

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from yacs.config import CfgNode as CN


_C = CN()

_C.OUTPUT_DIR = ''
_C.LOG_DIR = ''
_C.GPUS = (0,)
_C.WORKERS = 4
_C.PRINT_FREQ = 20
_C.AUTO_RESUME = False
_C.PIN_MEMORY = True

# Cudnn related params
_C.CUDNN = CN()
_C.CUDNN.BENCHMARK = True
_C.CUDNN.DETERMINISTIC = False
_C.CUDNN.ENABLED = True

# common params for NETWORK
_C.MODEL = CN()
_C.MODEL.NAME = 'pidnet_s'
_C.MODEL.PRETRAINED = 'pretrained_models/imagenet/PIDNet_S_ImageNet.pth.tar'
_C.MODEL.ALIGN_CORNERS = True
_C.MODEL.NUM_OUTPUTS = 2


_C.LOSS = CN()
_C.LOSS.USE_OHEM = True
_C.LOSS.OHEMTHRES = 0.9
_C.LOSS.OHEMKEEP = 100000
_C.LOSS.CLASS_BALANCE = False
_C.LOSS.BALANCE_WEIGHTS = [0.5, 0.5]
_C.LOSS.SB_WEIGHTS = 0.5

# DATASET related params
_C.DATASET = CN()
_C.DATASET.ROOT = 'data/'
_C.DATASET.DATASET = 'cityscapes'
_C.DATASET.NUM_CLASSES = 19
_C.DATASET.TRAIN_SET = 'list/cityscapes/train.lst'
_C.DATASET.EXTRA_TRAIN_SET = ''
_C.DATASET.TEST_SET = 'list/cityscapes/val.lst'

# training
_C.FINETUNE = False
_C.TRAIN = CN()
_C.TRAIN.IMAGE_SIZE = [1024, 1024] # width * height
_C.TRAIN.BASE_SIZE = 2048
_C.TRAIN.FLIP = True
_C.TRAIN.MULTI_SCALE = True
_C.TRAIN.SCALE_FACTOR = 16

_C.TRAIN.LR = 0.01
_C.TRAIN.EXTRA_LR = 0.001

_C.TRAIN.OPTIMIZER = 'sgd'
_C.TRAIN.MOMENTUM = 0.9
_C.TRAIN.WD = 0.0001
_C.TRAIN.NESTEROV = False
_C.TRAIN.IGNORE_LABEL = -1

_C.TRAIN.BEGIN_EPOCH = 0
_C.TRAIN.END_EPOCH = 484
_C.TRAIN.EXTRA_EPOCH = 0

_C.TRAIN.RESUME = False

_C.TRAIN.BATCH_SIZE_PER_GPU = 32
_C.TRAIN.SHUFFLE = True

# testing
_C.TEST = CN()
_C.TEST.IMAGE_SIZE = [2048, 1024] # width * height
_C.TEST.BASE_SIZE = 2048
_C.TEST.BATCH_SIZE_PER_GPU = 32
_C.TEST.MODEL_FILE = ''
_C.TEST.FLIP_TEST = False
_C.TEST.MULTI_SCALE = False

_C.TEST.OUTPUT_INDEX = -1


def update_config(cfg, args):
cfg.defrost()

cfg.merge_from_file(args.cfg)
cfg.merge_from_list(args.opts)

cfg.freeze()


if __name__ == '__main__':
import sys
with open(sys.argv[1], 'w') as f:
print(_C, file=f)

Loading