Unify feature extraction: replace standalone FeatureExtractor with classifier-integrated get_features() #123

Description

@mihow

Background

The codebase currently has two separate feature extraction implementations that both extract embeddings from the same ResNet50 backbone but serve different purposes and are maintained independently.

Current State

1. New: Resnet50TimmClassifier.get_features() (added in PR #77)

  • Location: trapdata/ml/models/classification.py:315-325
  • Extracts 2048-dim raw embeddings using timm's forward_features() + adaptive avg pool
  • Called during classification in APIMothClassifier.predict_batch() when include_features=True
  • Returns features in API response / worker output
  • Features are not saved to the database
  • Only works for timm-based ResNet50 classifiers (GlobalMothSpeciesClassifier, QuebecVermont2024, UKDenmark2024, Panama2024)

2. Old: FeatureExtractor class (tracking pipeline)

  • Location: trapdata/ml/models/tracking.py:430-481
  • Runs as separate pipeline stage 4 (after classification, before tracking)
  • Strips last layer from classifier model: nn.Sequential(*list(model.children())[:-1])
  • Applies L1 normalization per batch
  • Saves to DetectedObject.cnn_features JSON column in database
  • Used by the tracking system for cross-frame matching via cosine similarity
  • Has its own queue system (ObjectsWithoutFeaturesQueue)
  • Multiple subclasses for different classifiers (MothNonMothFeatureExtractor, QuebecVermontFeatureExtractor, etc.)
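
The normalization and matching steps listed above can be sketched in plain Python. This is a minimal illustration, not the code in tracking.py: the helper names `l1_normalize` and `cosine_similarity` are hypothetical, and it assumes normalization is applied to each feature vector on its own, which the "per batch" wording above does not confirm.

```python
import math
from typing import Sequence


def l1_normalize(vec: Sequence[float]) -> list[float]:
    """Scale a vector so its absolute values sum to 1 (hypothetical helper)."""
    total = sum(abs(v) for v in vec)
    if total == 0.0:
        return list(vec)  # leave an all-zero vector unchanged
    return [v / total for v in vec]


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine of the angle between two equal-length vectors."""
    if len(a) != len(b):
        raise ValueError("feature vectors must have the same length")
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        return 0.0  # convention chosen here for degenerate inputs
    return dot / (norm_a * norm_b)


# Toy 4-dim vectors standing in for the 2048-dim embeddings.
features_frame_1 = [0.5, 1.5, 0.0, 2.0]
features_frame_2 = [1.0, 3.0, 0.0, 4.0]  # same direction, twice the scale

raw_score = cosine_similarity(features_frame_1, features_frame_2)
normalized_score = cosine_similarity(
    l1_normalize(features_frame_1), l1_normalize(features_frame_2)
)
```

Cosine similarity ignores positive rescaling of either vector, so under the per-vector assumption the raw and normalized scores agree. Normalization would change the score only if it is computed across the batch, or if tracking uses a different distance measure.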

Problems

  1. Redundant forward passes: Both extract from the same backbone, but the old extractor loads the model separately and runs a second full forward pass on crops that were already classified.
  2. Redundant model loading: The old FeatureExtractor loads the same model weights into GPU memory again.
  3. Inconsistent extraction: forward_features() + adaptive pool vs nn.Sequential(*children()[:-1]) may produce slightly different features depending on model architecture.
  4. Two systems to maintain: Separate queue, pipeline stage, and subclass hierarchy for a fundamentally identical operation.
  5. No cross-use: Features from the new API-only extractor can't be used for tracking, and features from the old tracking extractor can't be returned in API responses.

Proposed Unification

Goal: Single feature extraction primitive on the classifier, used by both API responses and tracking.

Key insight: If classification already ran on a crop, the backbone features are available for free during that forward pass. Extracting them separately is wasted compute.
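
To make the cost argument concrete, here is a toy sketch that counts backbone invocations in the two flows. Everything in it is a stand-in: `CountingBackbone`, `classify`, and `extract_features_separately` are hypothetical names, and the "embedding" is a trivial arithmetic placeholder rather than a real network.

```python
from dataclasses import dataclass


@dataclass
class CountingBackbone:
    """Stand-in for a CNN backbone that counts how often it runs."""

    calls: int = 0

    def __call__(self, crop: list[float]) -> list[float]:
        self.calls += 1
        # Fake "embedding": real code would run the network here.
        return [2.0 * v for v in crop]


def classify(backbone, crop, return_features=False):
    """One backbone pass; the classification head reuses the embedding."""
    features = backbone(crop)
    score = sum(features)  # stand-in for the classification head
    return (score, features) if return_features else (score, None)


def extract_features_separately(backbone, crop):
    """Separate stage: runs the backbone again on the same crop."""
    return backbone(crop)


crop = [0.1, 0.2, 0.3]

# Two-stage flow: classify, then extract features in a later stage.
two_stage = CountingBackbone()
classify(two_stage, crop)
separate_features = extract_features_separately(two_stage, crop)

# Unified flow: keep the features produced during classification.
unified = CountingBackbone()
_, reused_features = classify(unified, crop, return_features=True)
```

In this sketch the two-stage flow runs the backbone twice per crop and the unified flow runs it once, while both end up with the same feature vector.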

Changes

  1. Make get_features() the universal primitive — already on Resnet50TimmClassifier, add to other classifier bases as needed.

  2. Store features during classification when tracking is enabled:

    • Add a store_features: bool flag (distinct from include_features which controls API response inclusion)
    • In save_results(), write ClassifierResult.features to DetectedObject.cnn_features
    • L1 normalization moves to the tracking module (consumer concern, not extraction concern)
  3. Remove FeatureExtractor pipeline stage:

    • Drop stage 4 from the pipeline
    • Remove FeatureExtractor class and subclasses
    • Remove ObjectsWithoutFeaturesQueue
    • Tracking reads features from DetectedObject.cnn_features (already written by classifier)
  4. Extend get_features() to the remaining classifiers:

    • Resnet50Classifier (custom ResNet50 module): extract from self.backbone + self.avgpool
    • EfficientNetClassifier: extract via timm's forward_features()
    • ConvNeXtOrderClassifier: extract via timm's forward_features()
    • Resnet50Classifier_Turing: extract from self.backbone + self.avgpool
  5. Update tracking to normalize on read:

    • compare_objects() applies L1 normalization before cosine similarity
    • Alternative: normalize once when saving (simpler, though it puts normalization back on the extraction side, contrary to change 2)
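
The split between store_features and include_features in change 2 could look roughly like the sketch below. It reuses the names ClassifierResult.features, DetectedObject.cnn_features, and save_results() from the list above, but the dataclass fields `label`, `score`, and `id`, the function signatures, and the `to_api_response` helper are assumptions made for illustration, not the project's actual definitions.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ClassifierResult:
    label: str  # hypothetical field
    score: float  # hypothetical field
    features: Optional[list[float]] = None


@dataclass
class DetectedObject:
    id: int  # hypothetical field
    cnn_features: Optional[list[float]] = None


def save_results(obj: DetectedObject, result: ClassifierResult,
                 store_features: bool) -> None:
    """Persist raw features on the object only when tracking needs them."""
    if store_features and result.features is not None:
        obj.cnn_features = list(result.features)


def to_api_response(result: ClassifierResult, include_features: bool) -> dict:
    """Build the response payload; features are included only on request."""
    payload = {"label": result.label, "score": result.score}
    if include_features:
        payload["features"] = result.features
    return payload


result = ClassifierResult(label="moth", score=0.93, features=[0.2, 0.8])
obj = DetectedObject(id=1)

# Tracking enabled, but the API caller did not ask for features.
save_results(obj, result, store_features=True)
response = to_api_response(result, include_features=False)
```

Keeping the two flags independent means a deployment can persist features for tracking without enlarging API responses, and vice versa.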

Pipeline simplification

Before: Localization → Binary → Species → Feature Extraction → Tracking  (5 stages)
After:  Localization → Binary → Species (with features) → Tracking        (4 stages)

Files to Modify

| File | Change |
| --- | --- |
| trapdata/ml/models/classification.py | Add get_features() to EfficientNet, ConvNeXt, Turing classifiers |
| trapdata/ml/models/tracking.py | Remove FeatureExtractor class + subclasses, update tracking to read from classifier-stored features |
| trapdata/ml/pipeline.py | Remove stage 4, update pipeline flow |
| trapdata/db/models/queue.py | Remove ObjectsWithoutFeaturesQueue |
| trapdata/api/models/classification.py | Optionally save features to DB in save_results() |

Risks

  • Feature compatibility: If the new extraction produces slightly different features than the old method, existing tracked sequences could have inconsistent features in the database. Mitigate by re-extracting all features on migration.
  • Non-timm models: Need to verify get_features() produces correct output for each backbone architecture before removing the old extractor.
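
One way to check both risks before removing the old extractor is to compare stored features against freshly extracted ones and list the objects that disagree. The sketch below is a hypothetical helper, not existing project code: `find_mismatched_features`, its tolerances, and the toy data are all assumptions, and it presumes both feature sets have already been brought to the same normalization.

```python
import math
from typing import Mapping, Sequence


def find_mismatched_features(
    stored: Mapping[int, Sequence[float]],
    recomputed: Mapping[int, Sequence[float]],
    rel_tol: float = 1e-4,
    abs_tol: float = 1e-6,
) -> list[int]:
    """Return ids whose stored and recomputed features differ or are missing."""
    mismatched = []
    for object_id, old in stored.items():
        new = recomputed.get(object_id)
        if new is None or len(new) != len(old):
            mismatched.append(object_id)
            continue
        same = all(
            math.isclose(o, n, rel_tol=rel_tol, abs_tol=abs_tol)
            for o, n in zip(old, new)
        )
        if not same:
            mismatched.append(object_id)
    return sorted(mismatched)


# Toy 2-dim vectors keyed by a made-up object id.
stored = {1: [0.10, 0.90], 2: [0.40, 0.60], 3: [0.25, 0.75]}
recomputed = {1: [0.10, 0.90], 2: [0.41, 0.59]}  # id 3 was never recomputed

needs_reextraction = find_mismatched_features(stored, recomputed)
```

An empty result for a given backbone would suggest the new extraction is a drop-in replacement there; a non-empty one identifies which rows a migration would need to re-extract.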
