Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
425 changes: 425 additions & 0 deletions docs/design.md

Large diffs are not rendered by default.

7 changes: 7 additions & 0 deletions lightning_action/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import sys
from argparse import ArgumentParser

from lightning_action import __version__
from lightning_action.cli import formatting
from lightning_action.cli.commands import COMMANDS

Expand All @@ -16,6 +17,12 @@ def build_parser() -> ArgumentParser:
description='Lightning-based action segmentation for behavioral analysis.',
)

parser.add_argument(
'--version',
action='version',
version=f'%(prog)s {__version__}',
)

subparsers = parser.add_subparsers(
dest='command',
required=True,
Expand Down
26 changes: 18 additions & 8 deletions lightning_action/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,21 @@

import logging
from pathlib import Path
from typing import Any
from typing import Any, Literal, get_args

import numpy as np
import pandas as pd
from jaxtyping import Float, Int

logger = logging.getLogger(__name__)

# Model-type names that compute_sequence_pad has an explicit branch for.
# Each row lists the spellings that share one branch; all are lowercase
# because the function lowercases its argument before comparing.
ModelType = Literal[
'temporal-mlp', 'temporalmlp',
'tcn',
'dtcn', 'dilatedtcn',
'lstm', 'gru', 'rnn',
]


def compute_sequences(
data: Float[np.ndarray, 'n_frames ...'] | list,
Expand Down Expand Up @@ -72,7 +79,7 @@ def compute_sequences(


def compute_sequence_pad(
model_type: str,
model_type: ModelType,
default: int | None = None,
**model_params: Any,
) -> int:
Expand Down Expand Up @@ -103,32 +110,35 @@ def compute_sequence_pad(
# Unknown model type with default fallback
pad = compute_sequence_pad('transformer', default=0, num_layers=4)
"""
model_type = model_type.lower()
model_type_lower = model_type.lower()

if model_type in ['temporal-mlp', 'temporalmlp']:
if model_type_lower in ['temporal-mlp', 'temporalmlp']:
return model_params['num_lags']

elif model_type == 'tcn':
elif model_type_lower == 'tcn':
num_layers = model_params['num_layers']
num_lags = model_params['num_lags']
return (2 ** num_layers) * num_lags

elif model_type in ['dtcn', 'dilatedtcn']:
elif model_type_lower in ['dtcn', 'dilatedtcn']:
# dilated TCN with more complex calculation
# dilation of each dilation block is 2 ** layer_num
# 2 conv layers per dilation block
return sum(
[2 * (2 ** n) * model_params['num_lags'] for n in range(model_params['num_layers'])]
)

elif model_type in ['lstm', 'gru', 'rnn']:
elif model_type_lower in ['lstm', 'gru', 'rnn']:
# fixed warmup period for recurrent models
return 4

else:
if default is not None:
return default
raise ValueError(f'Unknown model type: {model_type}')
raise ValueError(
f'Unknown model type: {model_type}. '
f'Valid values: {", ".join(get_args(ModelType))}'
)


def load_marker_csv(file_path: str | Path) -> tuple[
Expand Down
2 changes: 1 addition & 1 deletion lightning_action/data/video_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def __init__(

# Calculate TCN padding based on head architecture
self.tcn_padding = compute_sequence_pad(
model_type=head,
model_type=head, # type: ignore[arg-type]
num_lags=num_lags,
num_layers=num_layers,
default=0,
Expand Down
47 changes: 24 additions & 23 deletions lightning_action/models/backbones/resnet_beast.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,25 @@
import torch
import torch.nn as nn

# Hidden sizes for each architecture
BEAST_RESNET_HIDDEN_SIZES = {
'resnet18': 512,
'resnet34': 512,
'resnet50': 2048,
'resnet101': 2048,
'resnet152': 2048,
}

# Lookup table behind get_configs(): architecture name -> (list of four ints, bool).
# NOTE(review): presumably the ints are residual blocks per stage and the bool
# selects bottleneck blocks (True only for resnet50 and deeper) -- confirm
# against the backbone constructor.
_RESNET_CONFIGS: dict[str, tuple[list[int], bool]] = {
'resnet18': ([2, 2, 2, 2], False),
'resnet34': ([3, 4, 6, 3], False),
'resnet50': ([3, 4, 6, 3], True),
'resnet101': ([3, 4, 23, 3], True),
'resnet152': ([3, 8, 36, 3], True),
}


def get_configs(arch: str = 'resnet50') -> tuple:
def get_configs(arch: str = 'resnet50') -> tuple[list[int], bool]:
"""Get number and type of layers for resnet models.

Args:
Expand All @@ -41,28 +58,12 @@ def get_configs(arch: str = 'resnet50') -> tuple:
Raises:
ValueError: If architecture is not supported.
"""
if arch == 'resnet18':
return [2, 2, 2, 2], False
elif arch == 'resnet34':
return [3, 4, 6, 3], False
elif arch == 'resnet50':
return [3, 4, 6, 3], True
elif arch == 'resnet101':
return [3, 4, 23, 3], True
elif arch == 'resnet152':
return [3, 8, 36, 3], True
else:
raise ValueError(f'{arch} is not a valid ResNet architecture')


# Hidden sizes for each architecture
BEAST_RESNET_HIDDEN_SIZES = {
'resnet18': 512,
'resnet34': 512,
'resnet50': 2048,
'resnet101': 2048,
'resnet152': 2048,
}
if arch not in _RESNET_CONFIGS:
raise ValueError(
f'{arch} is an invalid entry in model.backbone. '
f'Valid values: {", ".join(_RESNET_CONFIGS.keys())}'
)
return _RESNET_CONFIGS[arch]


class ResNetBeastBackbone(nn.Module):
Expand Down
13 changes: 10 additions & 3 deletions lightning_action/models/heads/rnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,14 @@
processing capability.
"""

from typing import Literal, get_args

import torch
from jaxtyping import Float
from torch import nn

RnnType = Literal['lstm', 'gru']


class RNN(nn.Module):
"""RNN head for temporal sequence modeling.
Expand All @@ -22,7 +26,7 @@ def __init__(
input_size: int,
num_hid_units: int,
num_layers: int,
rnn_type: str = 'lstm',
rnn_type: RnnType = 'lstm',
bidirectional: bool = False,
dropout_rate: float = 0.0,
seed: int = 42,
Expand Down Expand Up @@ -52,8 +56,11 @@ def __init__(
self.seed = seed

# validate rnn type
if self.rnn_type not in ['lstm', 'gru']:
raise ValueError(f'Invalid rnn_type "{rnn_type}"; must be "lstm" or "gru"')
if self.rnn_type not in get_args(RnnType):
raise ValueError(
f'Invalid rnn_type "{rnn_type}". '
f'Valid values: {", ".join(get_args(RnnType))}'
)

# set random seed
torch.manual_seed(seed)
Expand Down
17 changes: 12 additions & 5 deletions lightning_action/models/heads/tcn.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,14 @@
residual connections for temporal modeling.
"""

from typing import Literal, get_args

import torch
from jaxtyping import Float
from torch import nn

ActivationType = Literal['relu', 'lrelu', 'sigmoid', 'tanh', 'linear']


class DilatedTCN(nn.Module):
"""Dilated Temporal Convolutional Network head.
Expand All @@ -21,7 +25,7 @@ def __init__(
num_hid_units: int,
num_layers: int,
num_lags: int = 1,
activation: str = 'lrelu',
activation: ActivationType = 'lrelu',
dropout_rate: float = 0.2,
seed: int = 42,
) -> None:
Expand All @@ -45,7 +49,7 @@ def __init__(
self.num_hid_units = num_hid_units
self.num_layers = num_layers
self.num_lags = num_lags
self.activation = activation
self.activation: ActivationType = activation
self.dropout_rate = dropout_rate
self.seed = seed

Expand Down Expand Up @@ -142,9 +146,9 @@ def __init__(
kernel_size: int,
stride: int = 1,
dilation: int = 2,
activation: str = 'lrelu',
activation: ActivationType = 'lrelu',
dropout: float = 0.2,
final_activation: str | None = None,
final_activation: ActivationType | None = None,
) -> None:
"""Initialize DilationBlock.

Expand Down Expand Up @@ -245,7 +249,10 @@ def _get_activation_func(activation: str) -> nn.Module:
elif activation == 'linear':
return nn.Identity()
else:
raise ValueError(f'Unsupported activation: {activation}')
raise ValueError(
f'Unsupported activation: {activation}. '
f'Valid values: {", ".join(get_args(ActivationType))}'
)

def _init_weights(self) -> None:
"""Initialize weights with normal distribution."""
Expand Down
11 changes: 9 additions & 2 deletions lightning_action/models/heads/temporalmlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,14 @@
which uses 1D convolution for temporal context followed by dense layers.
"""

from typing import Literal, get_args

import torch
import torch.nn as nn
from jaxtyping import Float

ActivationType = Literal['relu', 'lrelu', 'sigmoid', 'tanh', 'linear']


class TemporalMLP(nn.Module):
"""Temporal Multi-Layer Perceptron for sequence encoding.
Expand All @@ -30,7 +34,7 @@ def __init__(
num_hid_units: int,
num_layers: int,
num_lags: int = 5,
activation: str = 'lrelu',
activation: ActivationType = 'lrelu',
dropout_rate: float = 0.0,
seed: int = 42,
) -> None:
Expand Down Expand Up @@ -117,7 +121,10 @@ def _get_activation(self) -> nn.Module:
# `if self.activation != 'linear'`
return nn.Identity()
else:
raise ValueError(f'Unsupported activation: {self.activation}')
raise ValueError(
f'Unsupported activation: {self.activation}. '
f'Valid values: {", ".join(get_args(ActivationType))}'
)

def forward(
self,
Expand Down
15 changes: 12 additions & 3 deletions lightning_action/models/segmenter.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,10 @@ def configure_optimizers(self) -> dict[str, Any]:
params, lr=lr, weight_decay=weight_decay, momentum=momentum
)
else:
raise ValueError(f'Unsupported optimizer type: {optimizer_type}')
raise ValueError(
f'Unsupported optimizer type: {optimizer_type}. '
f'Valid values: adam, adamw, sgd'
)

# Parse scheduler config (support both flat and nested structures)
scheduler_config = optimizer_config.get('scheduler', None)
Expand Down Expand Up @@ -408,7 +411,10 @@ def configure_optimizers(self) -> dict[str, Any]:
}

else:
raise ValueError(f'Unsupported scheduler type: {scheduler_type}')
raise ValueError(
f'Unsupported scheduler type: {scheduler_type}. '
f'Valid values: step, cosine, cosine_warm_restarts, reduce_on_plateau'
)

return {
'optimizer': optimizer,
Expand Down Expand Up @@ -495,7 +501,10 @@ def _build_head(self) -> nn.Module:
seed=self.model_config.get('seed', 42),
)
else:
raise ValueError(f'Unsupported head type: {head_type}')
raise ValueError(
f'Unsupported head type: {head_type}. '
f'Valid values: temporalmlp, rnn, dtcn'
)

def _get_head_output_size(self) -> int:
"""Get the output size of the head network.
Expand Down
5 changes: 4 additions & 1 deletion lightning_action/models/video_segmenter.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,7 +467,10 @@ def _build_head(self) -> nn.Module:
seed=self.model_config.get('seed', 42),
)
else:
raise ValueError(f'Unsupported head type: {head_type}')
raise ValueError(
f'Unsupported head type: {head_type}. '
f'Valid values: temporalmlp, rnn, dtcn'
)

def _get_head_output_size(self) -> int:
"""Get the output feature dimension of the head."""
Expand Down
13 changes: 12 additions & 1 deletion lightning_action/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
trained_model = train(config, model, output_dir='runs/experiment1')
"""

import inspect
import logging
from pathlib import Path
from typing import Any
Expand Down Expand Up @@ -371,7 +372,17 @@ def build_data_config_from_path(
transform_classes = []
for t_name in transforms:
if not hasattr(transform_module, t_name):
raise ValueError(f"Unknown transform class: {t_name}")
available = sorted(
name for name, cls in inspect.getmembers(
transform_module, inspect.isclass,
)
if issubclass(cls, transform_module.Transform)
and cls is not transform_module.Transform
)
raise ValueError(
f"Unknown transform class: {t_name}. "
f"Available transforms: {', '.join(available)}"
)
transform_classes.append(getattr(transform_module, t_name)())

# Build config
Expand Down
9 changes: 9 additions & 0 deletions tests/cli/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,3 +58,12 @@ def test_no_args_prints_help_and_exits(self, monkeypatch, capsys):
with pytest.raises(SystemExit) as exc_info:
main()
assert exc_info.value.code == 1

def test_version_flag(self, monkeypatch, capsys):
    """Test that --version prints the program name and a version, then exits with code 0."""
    monkeypatch.setattr('sys.argv', ['lightning-action', '--version'])
    with pytest.raises(SystemExit) as exc_info:
        main()
    assert exc_info.value.code == 0
    captured = capsys.readouterr()
    # split off the last whitespace-separated token so that an empty version
    # string is caught instead of passing on the program name alone
    prog, _, version = captured.out.strip().rpartition(' ')
    assert 'lightning-action' in prog
    assert version
Loading
Loading