Background
The codebase currently has two separate feature extraction implementations that both extract embeddings from the same ResNet50 backbone but serve different purposes and are maintained independently.
Current State
1. New: Resnet50TimmClassifier.get_features() (added in PR #77)
- Location: trapdata/ml/models/classification.py:315-325
- Extracts 2048-dim raw embeddings using timm's forward_features() + adaptive avg pool
- Called during classification in APIMothClassifier.predict_batch() when include_features=True
- Returns features in API response / worker output
- Features are not saved to the database
- Only works for timm-based ResNet50 classifiers (GlobalMothSpeciesClassifier, QuebecVermont2024, UKDenmark2024, Panama2024)
2. Old: FeatureExtractor class (tracking pipeline)
- Location: trapdata/ml/models/tracking.py:430-481
- Runs as a separate pipeline stage (stage 4, after classification, before tracking)
- Strips the last layer from the classifier model: nn.Sequential(*list(model.children())[:-1])
- Applies L1 normalization per batch
- Saves to the DetectedObject.cnn_features JSON column in the database
- Used by the tracking system for cross-frame matching via cosine similarity
- Has its own queue system (ObjectsWithoutFeaturesQueue)
- Multiple subclasses for different classifiers (MothNonMothFeatureExtractor, QuebecVermontFeatureExtractor, etc.)
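The two extraction paths described above can be sketched side by side. This is a minimal sketch under stated assumptions, not the code in classification.py or tracking.py. It assumes a timm-style model that exposes forward_features() and whose last registered child is the classification head. `TinyTimmLikeNet` is a hypothetical stand-in with 8 channels instead of ResNet50's 2048, so the example runs without downloading weights. It also reads "L1 normalization per batch" as per-row normalization within a batch (`F.normalize(p=1, dim=1)`), which is an assumption.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyTimmLikeNet(nn.Module):
    """Hypothetical stand-in for a timm ResNet50. Children are registered in
    forward order, and the last child (`fc`) is the classification head."""

    def __init__(self, num_features: int = 8, num_classes: int = 3):
        super().__init__()
        self.body = nn.Sequential(nn.Conv2d(3, num_features, 3, padding=1), nn.ReLU())
        # Pool + flatten, roughly like timm's global_pool with flatten enabled.
        self.global_pool = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten())
        self.fc = nn.Linear(num_features, num_classes)

    def forward_features(self, x: torch.Tensor) -> torch.Tensor:
        return self.body(x)  # unpooled feature map: (B, C, H, W)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc(self.global_pool(self.forward_features(x)))


@torch.no_grad()
def new_get_features(model: nn.Module, x: torch.Tensor) -> torch.Tensor:
    # New path: forward_features() + adaptive average pool -> raw (B, C) embeddings.
    return F.adaptive_avg_pool2d(model.forward_features(x), 1).flatten(1)


@torch.no_grad()
def old_feature_extractor(model: nn.Module, x: torch.Tensor) -> torch.Tensor:
    # Old path: drop the last child, run the remaining children, L1-normalize each row.
    backbone = nn.Sequential(*list(model.children())[:-1])
    feats = backbone(x).flatten(1)
    return F.normalize(feats, p=1, dim=1)


model = TinyTimmLikeNet().eval()
crops = torch.randn(4, 3, 32, 32)  # a batch of 4 RGB crops
new = new_get_features(model, crops)
old = old_feature_extractor(model, crops)
print(new.shape, old.shape)  # torch.Size([4, 8]) torch.Size([4, 8])
```

On this toy model the old output is exactly the L1-normalized new output. On real architectures that identity has to be checked rather than assumed, which is the "Inconsistent extraction" problem listed under Problems.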
Problems
- Redundant forward passes: Both extract from the same backbone, but the old extractor loads the model separately and runs a second full forward pass on crops that were already classified.
- Redundant model loading: The old FeatureExtractor loads the same model weights into GPU memory again.
- Inconsistent extraction: forward_features() + adaptive pool and nn.Sequential(*list(model.children())[:-1]) may produce slightly different features depending on the model architecture.
- Two systems to maintain: Separate queue, pipeline stage, and subclass hierarchy for a fundamentally identical operation.
- No cross-use: Features from the new API-only extractor can't be used for tracking, and features from the old tracking extractor can't be returned in API responses.
Proposed Unification
Goal: Single feature extraction primitive on the classifier, used by both API responses and tracking.
Key insight: If classification already ran on a crop, the backbone features are available for free during that forward pass. Extracting them separately is wasted compute.
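The key insight can be made concrete. A classifier can pool the backbone's feature map once and feed that pooled vector to its head, so a single backbone pass yields both the logits and the embedding. This is a minimal sketch with a hypothetical `classify_with_features()` helper and toy model. It makes no claim about how APIMothClassifier.predict_batch() is actually structured.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyTimmLikeNet(nn.Module):
    """Hypothetical stand-in for a timm backbone with a linear head `fc`."""

    def __init__(self, num_features: int = 8, num_classes: int = 3):
        super().__init__()
        self.body = nn.Sequential(nn.Conv2d(3, num_features, 3, padding=1), nn.ReLU())
        self.fc = nn.Linear(num_features, num_classes)

    def forward_features(self, x: torch.Tensor) -> torch.Tensor:
        return self.body(x)  # (B, C, H, W)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        pooled = F.adaptive_avg_pool2d(self.forward_features(x), 1).flatten(1)
        return self.fc(pooled)


@torch.no_grad()
def classify_with_features(model: TinyTimmLikeNet, x: torch.Tensor):
    """Run the backbone once and return (logits, raw pooled features)."""
    features = F.adaptive_avg_pool2d(model.forward_features(x), 1).flatten(1)
    logits = model.fc(features)  # head reuses the same features; no second backbone pass
    return logits, features


model = TinyTimmLikeNet().eval()
crops = torch.randn(4, 3, 32, 32)
logits, features = classify_with_features(model, crops)
print(logits.shape, features.shape)  # torch.Size([4, 3]) torch.Size([4, 8])
```

With a real timm model, recent timm versions expose the same split. `forward_head(fmap, pre_logits=True)` returns the pooled embedding, and `get_classifier()` returns the final layer.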
Changes
- Make get_features() the universal primitive: it already exists on Resnet50TimmClassifier; add it to the other classifier base classes as needed.
- Store features during classification when tracking is enabled:
  - Add a store_features: bool flag (distinct from include_features, which controls whether features are included in the API response)
  - In save_results(), write ClassifierResult.features to DetectedObject.cnn_features
  - Move L1 normalization to the tracking module (a consumer concern, not an extraction concern)
- Remove FeatureExtractor pipeline stage:
  - Drop stage 4 from the pipeline
  - Remove FeatureExtractor class and subclasses
  - Remove ObjectsWithoutFeaturesQueue
  - Tracking reads features from DetectedObject.cnn_features (already written by classifier)
- Extend get_features() to the remaining classifiers:
  - Resnet50Classifier (custom ResNet50 module): extract from self.backbone + self.avgpool
  - EfficientNetClassifier: extract via timm's forward_features()
  - ConvNeXtOrderClassifier: extract via timm's forward_features()
  - Resnet50Classifier_Turing: extract from self.backbone + self.avgpool
- Update tracking to normalize on read:
  - compare_objects() applies L1 normalization before cosine similarity
  - Alternatively, normalize once when saving (simpler)
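The stored-features flow and the normalize-on-read option can be sketched in plain Python. Everything here is hypothetical: the dataclasses stand in for ClassifierResult and DetectedObject, and the signatures of save_results() and compare_objects() are simplified. It only illustrates how the proposed responsibilities divide.

One property bears on the choice between normalizing on read and on save. Per-vector L1 normalization only rescales each embedding by a positive constant, and cosine similarity does not change under positive rescaling, so either option gives the same tracking scores. That would not hold if the existing code normalizes across the batch dimension instead.

```python
import math
from dataclasses import dataclass
from typing import Optional


@dataclass
class ClassifierResult:
    # Hypothetical result shape with only the fields this sketch needs.
    detection_id: int
    label: str
    features: Optional[list[float]] = None  # raw, unnormalized embedding


@dataclass
class DetectedObject:
    # Hypothetical stand-in for the DB model (cnn_features is a JSON column there).
    id: int
    label: Optional[str] = None
    cnn_features: Optional[list[float]] = None


def save_results(results, objects, store_features=False):
    """Write labels; write raw features only when store_features is set.

    include_features (API response inclusion) is a separate flag, not used here.
    """
    for result in results:
        obj = objects[result.detection_id]
        obj.label = result.label
        if store_features and result.features is not None:
            obj.cnn_features = list(result.features)


def l1_normalize(v, eps=1e-12):
    norm = sum(abs(x) for x in v)
    return [x / max(norm, eps) for x in v]


def cosine_similarity(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    denom = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / denom if denom else 0.0


def compare_objects(a, b):
    """Normalize on read, then score; None if either object has no stored features."""
    if not a.cnn_features or not b.cnn_features:
        return None
    return cosine_similarity(l1_normalize(a.cnn_features), l1_normalize(b.cnn_features))


objects = {i: DetectedObject(i) for i in (1, 2, 3)}
save_results(
    [
        ClassifierResult(1, "species_a", [0.2, 0.0, 1.3, 0.5]),
        ClassifierResult(2, "species_a", [0.1, 0.4, 1.0, 0.6]),
    ],
    objects,
    store_features=True,
)
score = compare_objects(objects[1], objects[2])
raw = cosine_similarity(objects[1].cnn_features, objects[2].cnn_features)
print(math.isclose(score, raw))                 # True: L1 scaling does not change cosine
print(compare_objects(objects[1], objects[3]))  # None: object 3 has no stored features
```

Because the queue is gone, the tracker has to tolerate objects without stored features, such as crops classified with store_features disabled. Returning None is one way to surface that case.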
Pipeline simplification
Before: Localization → Binary → Species → Feature Extraction → Tracking (5 stages)
After: Localization → Binary → Species (with features) → Tracking (4 stages)
Files to Modify
| File | Change |
|---|---|
| trapdata/ml/models/classification.py | Add get_features() to EfficientNet, ConvNeXt, Turing classifiers |
| trapdata/ml/models/tracking.py | Remove FeatureExtractor class + subclasses, update tracking to read from classifier-stored features |
| trapdata/ml/pipeline.py | Remove stage 4, update pipeline flow |
| trapdata/db/models/queue.py | Remove ObjectsWithoutFeaturesQueue |
| trapdata/api/models/classification.py | Optionally save features to DB in save_results() |
Risks
- Feature compatibility: If the new extraction produces slightly different features than the old method, the database could hold tracked sequences that mix features from both methods. Mitigate by re-extracting all features on migration.
- Non-timm models: Need to verify get_features() produces correct output for each backbone architecture before removing the old extractor.
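For the custom modules (Resnet50Classifier, Resnet50Classifier_Turing), get_features() would reuse self.backbone and self.avgpool. A parity check along the lines below could then be run per architecture before the old extractor is removed. It feeds the same crops through both paths and compares raw features, before any normalization. `TinyCustomResNet`, its `classifier` head attribute, and `check_feature_parity()` are hypothetical stand-ins.

The two paths agree here only because the children are registered in forward order and the head is a single linear layer. They can diverge otherwise. For example, timm's default ConvNeXt head applies a LayerNorm after pooling, so if ConvNeXtOrderClassifier uses that head, bare forward_features() + pooling would skip the norm. In that case forward_head(x, pre_logits=True) is likely the closer match; this is worth confirming against the installed timm version.

```python
import torch
import torch.nn as nn


class TinyCustomResNet(nn.Module):
    """Hypothetical analogue of a custom ResNet50 module: trunk in self.backbone,
    pooling in self.avgpool, classification head registered last."""

    def __init__(self, num_features: int = 8, num_classes: int = 3):
        super().__init__()
        self.backbone = nn.Sequential(nn.Conv2d(3, num_features, 3, padding=1), nn.ReLU())
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Linear(num_features, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.classifier(torch.flatten(self.avgpool(self.backbone(x)), 1))

    @torch.no_grad()
    def get_features(self, x: torch.Tensor) -> torch.Tensor:
        # Same computation as forward(), stopping before the classification head.
        return torch.flatten(self.avgpool(self.backbone(x)), 1)


@torch.no_grad()
def check_feature_parity(model: nn.Module, x: torch.Tensor, atol: float = 1e-5) -> bool:
    """Compare raw get_features() output with the old children()[:-1] stripping."""
    old = nn.Sequential(*list(model.children())[:-1])(x).flatten(1)
    new = model.get_features(x)
    if old.shape != new.shape:
        return False
    return torch.allclose(old, new, atol=atol)


model = TinyCustomResNet().eval()
crops = torch.randn(4, 3, 32, 32)
print(check_feature_parity(model, crops))  # True
```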
Related
- get_features() to Resnet50TimmClassifier and opt-in API response