Skip to content

Implement DINO object detection architecture in Keras 3#358

Open
mohamedamrfaroukebrahim wants to merge 7 commits intooarriaga:paz-jaxfrom
mohamedamrfaroukebrahim:paz-jax
Open

Implement DINO object detection architecture in Keras 3#358
mohamedamrfaroukebrahim wants to merge 7 commits intooarriaga:paz-jaxfrom
mohamedamrfaroukebrahim:paz-jax

Conversation

@mohamedamrfaroukebrahim

No description provided.

This commit introduces a fully functional, backend-agnostic Keras 3 implementation of the DINO (LW-DETR) object detection model. It includes the complete pipeline from the Vision Transformer (ViT) backbone wrapper to the Transformer decoder, segmentation heads, Hungarian matching logic, and training utilities. A comprehensive testing suite is included to verify strict numerical parity with the original PyTorch implementation.

**Core Implementation Details:**

* **Architecture:** Implemented hybrid encoder-decoder structure with MultiScale Projectors and DINOv2 backbone integration.
* **Panoptic Support:** Added Segmentation Heads with uncertainty-based point sampling.
* **Training:** Implemented Hungarian Matching, Focal Loss, GIoU loss, and layer-wise learning rate decay for ViT.
* **Verification:** Added a mirror "shadow" implementation of the original PyTorch code (prefixed `torch_`) alongside comprehensive parity tests to ensure Keras outputs match PyTorch baselines exactly.

**File Changes by Directory:**

`examples/dino_object_detection/models/`

* `config.py`: Dataclasses defining strict configuration schemas for model hyperparameters and training settings.

`examples/dino_object_detection/models/lwdetr/`

* `lwdetr_keras.py`: **Main Entry Point.** Assembles the Keras DINO model (Backbone + Transformer + Heads).
* `lwdetr_test.py`: End-to-end parity test verifying the full Keras forward pass matches PyTorch.
* `torch_lwdetr_for_testing.py`: Reference PyTorch implementation of the LW-DETR model.

`examples/dino_object_detection/models/backbone/`

* `dinov2_backbone_wrapper.py`: Wrapper to adapt Keras-native ViT models for hierarchical feature extraction.
* `projector.py`: Implements `MultiScaleProjector` and `SimpleProjector` (FPN) using depthwise-separable convs.
* `backbone_test.py`: Unit tests for backbone feature extraction and shape consistency.
* `projector_test.py`: Parity tests for the feature pyramid projectors.
* `torch_backbone_for_testing.py`: Reference PyTorch backbone logic.
* `torch_projector_for_testing.py`: Reference PyTorch projector logic.
* `torch_position_encoding_for_testing.py`: Reference PyTorch position embedding logic.

`examples/dino_object_detection/models/transformer_decoder_head/`

* `transformer_kerass.py`: Implements the Transformer Decoder, `PositionEmbeddingSine`, and attention mechanisms.
* `transformer_test.py`: Validates attention scores and decoder outputs against the reference.
* `torch_transformer_for_testing.py`: Reference PyTorch Transformer decoder implementation.

`examples/dino_object_detection/models/segmentation_head/`

* `segmentation_head.py`: Top-level logic for the mask prediction auxiliary head.
* `depthwise_conv_block.py`: Optimized depthwise convolution block used within the mask head.
* `mlp_block.py`: Multi-Layer Perceptron block for feature processing.
* `utils.py`: Implements point sampling and uncertainty-based sampling for mask loss efficiency.
* `segmentation_head_test.py`: Parity tests for the full segmentation head pipeline.
* `depthwise_conv_block_test.py` & `mlp_block_test.py`: Unit tests for specific blocks.
* `utils_test.py`: Verifies numerical stability of point sampling.
* `torch_segmentation_head_for_testing.py`: Reference PyTorch segmentation components.

`examples/dino_object_detection/models/matcher/`

* `matcher.py`: Keras implementation of the `HungarianMatcher` (using SciPy) for bipartite matching.
* `matcher_test.py`: Ensures Keras matching assignments are identical to PyTorch.
* `torch_matcher_for_testing.py`: Reference PyTorch Hungarian matcher.

`examples/dino_object_detection/models/utils/`

* `box_ops.py`: Backend-agnostic box operations (IoU, GIoU, cxcywh/xyxy conversion).
* `metrics.py`: Utilities for logging metrics (WandB) and plotting training curves.
* `benchmark.py`: Tool for profiling model FLOPs and FPS inference speed.
* `early_stopping.py`: Callback for monitoring mAP and halting training.
* `drop_scheduler.py`: Logic for scheduling DropPath rates during training.
* `get_param_dicts.py`: Logic for grouping parameters and applying layer-wise LR decay (ViT specific).
* `obj365_to_coco_model.py`: Utility for transferring weights from Obj365 pre-trained models.
* `misc.py`: General utilities (NestedTensor simulation, smoothed values).
* `files.py`: Helpers for downloading assets.
* `coco_classes.py`: Class ID to label mapping.
* `*_test.py`: Associated unit tests for benchmarking, parameter grouping, and misc utilities.
* `torch_*_for_testing.py`: Corresponding PyTorch reference implementations for all utilities.
@mohamedamrfaroukebrahim mohamedamrfaroukebrahim changed the title adding the transformer head of the object detection of the dino models Implement DINO object detection architecture in Keras 3 Jan 20, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant