Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 14 additions & 13 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# Lightning Action

![GitHub](https://img.shields.io/github/license/paninski-lab/lightning-action)
![PyPI](https://img.shields.io/pypi/v/lightning-action)
[![codecov](https://codecov.io/gh/paninski-lab/lightning-action/branch/main/graph/badge.svg)](https://codecov.io/gh/paninski-lab/lightning-action)
![PyPI](https://img.shields.io/pypi/v/lightning-action)

A modern action segmentation framework built with PyTorch Lightning for behavioral analysis.

Expand Down Expand Up @@ -163,16 +163,12 @@ Lightning Action automatically logs training metrics to TensorBoard. To visualiz

---

## Contributing
### Contributing

See [CONTRIBUTING.md](CONTRIBUTING.md) for guidelines on setting up a development environment,
code style, and submitting pull requests.

## License

This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.

## Citation
### Citation

If you use this framework in your research, please cite:

Expand All @@ -187,9 +183,14 @@ If you use this framework in your research, please cite:
}
```

## Acknowledgments

This framework is built upon the work of:
- [PyTorch Lightning](https://lightning.ai/) for the training framework
- [PyTorch](https://pytorch.org/) for the deep learning backend
- Previous action segmentation work from the [Paninski Lab](https://github.com/themattinthehatt/daart)
### Funding

We are grateful for support from the following:
* Gatsby Charitable Foundation GAT3708
* [NIH R50NS145433](https://reporter.nih.gov/search/Hmj4KMmLv0evcYPlPEDa-Q/project-details/11240675)
* [NIH U19NS123716](https://reporter.nih.gov/search/Hmj4KMmLv0evcYPlPEDa-Q/project-details/11141703)
* [NSF 1707398](https://ui.adsabs.harvard.edu/abs/2017nsf....1707398A/abstract)
* [The NSF AI Institute for Artificial and Natural Intelligence](https://ui.adsabs.harvard.edu/abs/2023nsf....2229929Z/abstract)
* Simons Foundation
* Wellcome Trust 216324
* Zuckerman Institute (Columbia University) Team Science
2 changes: 2 additions & 0 deletions lightning_action/cli/formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ class ArgumentParser(argparse.ArgumentParser):
"""Enhanced argument parser with better formatting."""

def __init__(self, **kwargs: Any) -> None:
"""Initialize the argument parser with custom help formatting."""
super().__init__(
formatter_class=HelpFormatter,
**kwargs,
Expand All @@ -35,6 +36,7 @@ class SubArgumentParser(ArgumentParser):
"""Argument parser for subcommands."""

def __init__(self, *args: Any, **kwargs: Any) -> None:
"""Initialize the subcommand argument parser."""
super().__init__(*args, **kwargs)
self.is_sub_parser = True

Expand Down
81 changes: 81 additions & 0 deletions lightning_action/models/backbones/resnet_beast.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,13 @@ class ResNetBeast(nn.Module):
"""ResNet backbone from beast package."""

def __init__(self, configs: list, bottleneck: bool = False) -> None:
"""Initialize the ResNetBeast backbone.

Args:
configs: list of four integers specifying the number of layers per block.
bottleneck: if True, use bottleneck (1x1/3x3/1x1) layers instead of residual
(3x3/3x3) layers.
"""
super().__init__()

if len(configs) != 4:
Expand Down Expand Up @@ -286,6 +293,14 @@ def __init__(self, configs: list, bottleneck: bool = False) -> None:
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward pass through all five convolutional blocks.

Args:
x: input tensor of shape (B, C, H, W).

Returns:
spatial feature tensor of shape (B, hidden_dim, H', W').
"""
x = self.conv1(x)
x = self.conv2(x)
x = self.conv3(x)
Expand All @@ -304,6 +319,15 @@ def __init__(
layers: int,
downsample_method: Literal['conv', 'pool'] = 'conv',
) -> None:
"""Initialize the residual block.

Args:
in_channels: number of input channels.
hidden_channels: number of channels within and output of the block.
layers: number of residual layers to stack.
downsample_method: spatial downsampling strategy — 'conv' uses strided
convolution on the first layer; 'pool' uses max-pooling before the layers.
"""
super().__init__()

if downsample_method == 'conv':
Expand Down Expand Up @@ -342,6 +366,14 @@ def __init__(
self.add_module(f'{i + 1} EncoderLayer', layer)

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward pass through all residual layers.

Args:
x: input tensor of shape (B, C, H, W).

Returns:
output tensor of shape (B, hidden_channels, H', W').
"""
for _name, layer in self.named_children():
x = layer(x)
return x
Expand All @@ -358,6 +390,16 @@ def __init__(
layers: int,
downsample_method: Literal['conv', 'pool'] = 'conv',
) -> None:
"""Initialize the bottleneck block.

Args:
in_channels: number of input channels.
hidden_channels: number of channels in the compressed middle convolution.
up_channels: number of output channels after the 1x1 expansion convolution.
layers: number of bottleneck layers to stack.
downsample_method: spatial downsampling strategy — 'conv' uses strided
convolution on the first layer; 'pool' uses max-pooling before the layers.
"""
super().__init__()

if downsample_method == 'conv':
Expand Down Expand Up @@ -392,6 +434,14 @@ def __init__(
self.add_module(f'{i + 1} EncoderLayer', layer)

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward pass through all bottleneck layers.

Args:
x: input tensor of shape (B, C, H, W).

Returns:
output tensor of shape (B, up_channels, H', W').
"""
for _name, layer in self.named_children():
x = layer(x)
return x
Expand All @@ -406,6 +456,13 @@ def __init__(
hidden_channels: int,
downsample: bool,
) -> None:
"""Initialize the residual layer.

Args:
in_channels: number of input channels.
hidden_channels: number of output channels.
downsample: if True, apply stride-2 convolution and a matching skip connection.
"""
super().__init__()

if downsample:
Expand Down Expand Up @@ -449,6 +506,14 @@ def __init__(
self.relu = nn.Sequential(nn.ReLU(inplace=True))

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward pass through the residual layer.

Args:
x: input tensor of shape (B, in_channels, H, W).

Returns:
output tensor of shape (B, hidden_channels, H', W').
"""
identity = x
x = self.weight_layer1(x)
x = self.weight_layer2(x)
Expand All @@ -469,6 +534,14 @@ def __init__(
up_channels: int,
downsample: bool,
) -> None:
"""Initialize the bottleneck layer.

Args:
in_channels: number of input channels.
hidden_channels: number of channels in the compressed 3x3 convolution.
up_channels: number of output channels after the 1x1 expansion convolution.
downsample: if True, apply stride-2 convolution and a matching skip connection.
"""
super().__init__()

if downsample:
Expand Down Expand Up @@ -531,6 +604,14 @@ def __init__(
self.relu = nn.Sequential(nn.ReLU(inplace=True))

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward pass through the bottleneck layer.

Args:
x: input tensor of shape (B, in_channels, H, W).

Returns:
output tensor of shape (B, up_channels, H', W').
"""
identity = x
x = self.weight_layer1(x)
x = self.weight_layer2(x)
Expand Down
3 changes: 2 additions & 1 deletion lightning_action/models/backbones/vitmae.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,8 @@ def get_last_layer_params(self) -> Iterator[nn.Parameter]:
Returns:
Iterator over parameters of the last encoder layer and layernorm.
"""
for param in self.vit_mae.encoder.layer[-1].parameters():
encoder_layers: nn.ModuleList = self.vit_mae.encoder.layer # type: ignore[assignment]
for param in encoder_layers[-1].parameters():
yield param
for param in self.vit_mae.layernorm.parameters():
yield param
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ dependencies = [
"tensorboard",
"torch",
"torchvision",
"transformers",
"transformers (<5.9.0)",
"pandas",
"pyyaml",
]
Expand Down
Loading