Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 13 additions & 15 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,8 @@ jobs:
fail-fast: false
matrix:
python-version: ["3.8", "3.9", "3.10"]
pytorch-version: ["1.4.0", "1.5.1", "1.6.0", "1.7.1", "1.8", "1.9", "1.10", "1.11", "1.12", "1.13", "2.0", "2.1", "2.2"]
pytorch-version: ["1.7.1", "1.8", "1.9", "1.10", "1.11", "1.12", "1.13", "2.0", "2.1", "2.2"]
include:
- python-version: 3.11
pytorch-version: 2.0
- python-version: 3.11
pytorch-version: 2.1
- python-version: 3.11
Expand Down Expand Up @@ -50,30 +48,30 @@ jobs:
pytorch-version: 1.9
- python-version: 3.10
pytorch-version: 1.10
- python-version: 3.10
pytorch-version: 2.0

steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
- uses: actions/checkout@v4
- name: Install uv and set the python version
uses: astral-sh/setup-uv@v6
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade uv
uv pip install --system pytest pytest-cov
uv pip install --system torch==${{ matrix.pytorch-version }} torchvision transformers
uv add torch==${{ matrix.pytorch-version }}
uv sync
- name: mypy
if: ${{ matrix.pytorch-version == '2.2' }}
if: ${{ matrix.python-version == '3.11' && matrix.pytorch-version == '2.2' }}
run: |
uv pip install --system mypy==1.9.0
mypy --install-types --non-interactive .
uv run mypy --install-types --non-interactive .
- name: pytest
if: ${{ matrix.pytorch-version == '2.2' }}
run: |
pytest --cov=torchinfo --cov-report= --durations=0
uv run pytest --cov=torchinfo --cov-report= --durations=0
- name: pytest
if: ${{ matrix.pytorch-version != '2.2' }}
run: |
pytest --no-output -k "not test_eval_order_doesnt_matter and not test_google and not test_uninitialized_tensor and not test_input_size_half_precision and not test_recursive_with_missing_layers and not test_flan_t5_small"
uv run pytest --no-output -k "not test_eval_order_doesnt_matter and not test_google and not test_uninitialized_tensor and not test_input_size_half_precision and not test_recursive_with_missing_layers and not test_flan_t5_small"
- name: codecov
uses: codecov/codecov-action@v1
uses: codecov/codecov-action@v5
8 changes: 4 additions & 4 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,24 @@ ci:
skip: [mypy, pytest]
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.9.10
rev: v0.12.8
hooks:
- id: ruff
- id: ruff-check
args: [--fix]
- id: ruff-format

- repo: local
hooks:
- id: mypy
name: mypy
entry: mypy
entry: uv run mypy
language: python
types: [python]
require_serial: true

- id: pytest
name: pytest
entry: pytest --cov=torchinfo --cov-report=html --durations=0
entry: uv run pytest --cov=torchinfo --cov-report=html --durations=0
language: python
types: [python]
always_run: true
Expand Down
119 changes: 119 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
[project]
name = "torchinfo"
description = "Model summary in PyTorch, based off of the original torchsummary."
authors = [
{ name = "Tyler Yep @tyleryep", email = "tyep@cs.stanford.edu" },
]
readme = "README.md"
classifiers = [
"Operating System :: OS Independent",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.6",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
]
keywords = [
"torch pytorch torchsummary torch-summary summary keras deep-learning ml torchinfo torch-info visualize model statistics layer stats",
]
dynamic = [
"version",
]
requires-python = ">=3.8"
dependencies = [
"torch",
"torchvision",
"numpy",
]

[dependency-groups]
dev = [
"codecov",
"mypy",
"pre-commit",
"pytest",
"pytest-cov",
"ruff",
"setuptools<70",
"transformers",
"types-setuptools",
"types-tqdm",
]

[project.urls]
Homepage = "https://github.com/tyleryep/torchinfo"

[build-system]
requires = [
"setuptools>=61.2",
]
build-backend = "setuptools.build_meta"

[tool.setuptools]
packages = [
"torchinfo",
]
include-package-data = true

[tool.setuptools.package-data]
torchinfo = [
"py.typed",
]

[tool.setuptools.dynamic.version]
attr = "torchinfo.__version__"

[tool.mypy]
strict = true
warn_unreachable = true
disallow_any_unimported = true
extra_checks = true
enable_error_code = "ignore-without-code"

[tool.ruff]
target-version = "py38"
lint.select = ["ALL"]
lint.ignore = [
"ANN401", # Dynamically typed expressions (typing.Any) are disallowed
"C901", # function is too complex (12 > 10)
"COM812", # Trailing comma missing
"D", # Docstring rules
"EM101", # Exception must not use a string literal, assign to variable first
"EM102", # Exception must not use an f-string literal, assign to variable first
"ERA001", # Found commented-out code
"FBT001", # Boolean positional arg in function definition
"FBT002", # Boolean default value in function definition
"FBT003", # Boolean positional value in function call
"FIX002", # Line contains TODO
"PLR0911", # Too many return statements (11 > 6)
"PLR2004", # Magic value used in comparison, consider replacing 2 with a constant variable
"PLR0912", # Too many branches
"PLR0913", # Too many arguments to function call
"PLR0915", # Too many statements
"S101", # Use of `assert` detected
"S311", # Standard pseudo-random generators are not suitable for cryptographic purposes
"T201", # print() found
"T203", # pprint() found
"TD002", # Missing author in TODO; try: `# TODO(<author_name>): ...`
"TD003", # Missing issue link on the line following this TODO
"TD005", # Missing issue description after `TODO`
"TRY003", # Avoid specifying long messages outside the exception class

# torchinfo-specific ignores
"N803", # Argument name `A_i` should be lowercase
"N806", # Variable `G` in function should be lowercase
"BLE001", # Do not catch blind exception: `Exception`
"PLW0602", # Using global for `_cached_forward_pass` but no assignment is done
"PLW0603", # Using the global statement to update `_cached_forward_pass` is discouraged
"PLW2901", # `for` loop variable `name` overwritten by assignment target
"RUF005", # Consider unpack instead of concatenation
"SIM108", # [*] Use ternary operator `model_mode = Mode.EVAL if mode is None else Mode(mode)` instead of `if`-`else`-block
"SLF001", # Private member accessed: `_modules`
"TC002", # Move third-party import into a type-checking block
"TRY004", # Prefer `TypeError` exception for invalid type
"TRY301", # Abstract `raise` to an inner function
]
exclude = ["tests"] # TODO: check tests too
7 changes: 0 additions & 7 deletions requirements-dev.txt

This file was deleted.

3 changes: 0 additions & 3 deletions requirements.txt

This file was deleted.

46 changes: 0 additions & 46 deletions ruff.toml

This file was deleted.

39 changes: 0 additions & 39 deletions setup.cfg

This file was deleted.

4 changes: 2 additions & 2 deletions tests/fixtures/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def forward(self, x: torch.Tensor, scalar: float) -> torch.Tensor:
out = self.conv1(out)
else:
out = self.conv2(out)
return out
return out # type: ignore[no-any-return]


class LSTMNet(nn.Module):
Expand Down Expand Up @@ -468,7 +468,7 @@ def forward(self, batch: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
output = self.dropout_layer(ht[-1])
output = self.hidden2out(output)
output = F.log_softmax(output, dim=1)
return cast(torch.Tensor, output)
return output


class ContainerModule(nn.Module):
Expand Down
4 changes: 1 addition & 3 deletions tests/torchinfo_xl_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,7 @@
from torchinfo.enums import ColumnSettings

if version.parse(torch.__version__) >= version.parse("1.8"):
from transformers import ( # type: ignore[import-untyped]
AutoModelForSeq2SeqLM,
)
from transformers import AutoModelForSeq2SeqLM


def test_ascii_only() -> None:
Expand Down
3 changes: 2 additions & 1 deletion torchinfo/torchinfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ class name as the key. If the forward pass is an expensive operation,
See torchinfo/model_statistics.py for more information.
"""
input_data_specified = input_data is not None or input_size is not None
columns: tuple[ColumnSettings, ...]
if col_names is None:
columns = (
DEFAULT_COLUMN_NAMES
Expand Down Expand Up @@ -514,7 +515,7 @@ def get_total_memory_used(data: CORRECTED_INPUT_DATA_TYPE) -> int:
else sum
),
)
return cast(int, result)
return cast("int", result)


def get_input_tensor(
Expand Down
Loading
Loading