Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions alt_cuda_corr/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@


setup(
name='correlation',
name='raft-alt-cuda-corr',
ext_modules=[
CUDAExtension('alt_cuda_corr',
CUDAExtension('raft_alt_cuda_corr',
sources=['correlation.cpp', 'correlation_kernel.cu'],
extra_compile_args={'cxx': [], 'nvcc': ['-O3']}),
],
Expand Down
11 changes: 11 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
[build-system]
build-backend = "setuptools.build_meta"
requires = [ "setuptools" ]

[project]
name = "raft"
version = "0.0.1"
dependencies = [ "raft-alt-cuda-corr" ]

[tool.setuptools.packages.find]
exclude = [ "alt_cuda_corr" ]
File renamed without changes.
File renamed without changes.
8 changes: 4 additions & 4 deletions core/corr.py → raft/core/corr.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import torch
import torch.nn.functional as F
from utils.utils import bilinear_sampler, coords_grid
from raft.core.utils.utils import bilinear_sampler, coords_grid

try:
import alt_cuda_corr
import raft_alt_cuda_corr
except:
# alt_cuda_corr is not compiled
# raft_alt_cuda_corr is not compiled
pass


Expand Down Expand Up @@ -83,7 +83,7 @@ def __call__(self, coords):
fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1).contiguous()

coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous()
corr, = alt_cuda_corr.forward(fmap1_i, fmap2_i, coords_i, r)
corr, = raft_alt_cuda_corr.forward(fmap1_i, fmap2_i, coords_i, r)
corr_list.append(corr.squeeze(1))

corr = torch.stack(corr_list, dim=1)
Expand Down
4 changes: 2 additions & 2 deletions core/datasets.py → raft/core/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
from glob import glob
import os.path as osp

from utils import frame_utils
from utils.augmentor import FlowAugmentor, SparseFlowAugmentor
from raft.core.utils import frame_utils
from raft.core.utils.augmentor import FlowAugmentor, SparseFlowAugmentor


class FlowDataset(data.Dataset):
Expand Down
File renamed without changes.
8 changes: 4 additions & 4 deletions core/raft.py → raft/core/raft.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
import torch.nn as nn
import torch.nn.functional as F

from update import BasicUpdateBlock, SmallUpdateBlock
from extractor import BasicEncoder, SmallEncoder
from corr import CorrBlock, AlternateCorrBlock
from utils.utils import bilinear_sampler, coords_grid, upflow8
from raft.core.update import BasicUpdateBlock, SmallUpdateBlock
from raft.core.extractor import BasicEncoder, SmallEncoder
from raft.core.corr import CorrBlock, AlternateCorrBlock
from raft.core.utils.utils import bilinear_sampler, coords_grid, upflow8

try:
autocast = torch.cuda.amp.autocast
Expand Down
File renamed without changes.
Empty file added raft/core/utils/__init__.py
Empty file.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
8 changes: 3 additions & 5 deletions demo.py → raft/demo.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import sys
sys.path.append('core')

import argparse
import os
Expand All @@ -10,13 +9,12 @@
from PIL import Image

from raft import RAFT
from utils import flow_viz
from utils.utils import InputPadder


from raft.core.utils import flow_viz
from raft.core.utils.utils import InputPadder

DEVICE = 'cuda'


def load_image(imfile):
img = np.array(Image.open(imfile)).astype(np.uint8)
img = torch.from_numpy(img).permute(2, 0, 1).float()
Expand Down
7 changes: 3 additions & 4 deletions evaluate.py → raft/evaluate.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import sys
sys.path.append('core')

from PIL import Image
import argparse
Expand All @@ -11,11 +10,11 @@
import matplotlib.pyplot as plt

import datasets
from utils import flow_viz
from utils import frame_utils
from raft.core.utils import flow_viz
from raft.core.utils import frame_utils

from raft import RAFT
from utils.utils import InputPadder, forward_interpolate
from raft.core.utils.utils import InputPadder, forward_interpolate


@torch.no_grad()
Expand Down
11 changes: 6 additions & 5 deletions train.py → raft/train.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import print_function, division
import sys
sys.path.append('core')

import argparse
import os
Expand All @@ -16,11 +15,13 @@

from torch.utils.data import DataLoader
from raft import RAFT
import evaluate
import raft.evaluate
import datasets

from torch.utils.tensorboard import SummaryWriter

pass

try:
from torch.cuda.amp import GradScaler
except:
Expand Down Expand Up @@ -189,11 +190,11 @@ def train(args):
results = {}
for val_dataset in args.validation:
if val_dataset == 'chairs':
results.update(evaluate.validate_chairs(model.module))
results.update(raft.evaluate.validate_chairs(model.module))
elif val_dataset == 'sintel':
results.update(evaluate.validate_sintel(model.module))
results.update(raft.evaluate.validate_sintel(model.module))
elif val_dataset == 'kitti':
results.update(evaluate.validate_kitti(model.module))
results.update(raft.evaluate.validate_kitti(model.module))

logger.write_dict(results)

Expand Down