Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,8 @@ Zachary Teed and Jia Deng<br/>
## Requirements
The code has been tested with PyTorch 1.6 and Cuda 10.1.
```Shell
conda create --name raft
conda env create --name raft --file conda_env.yml
conda activate raft
conda install pytorch=1.6.0 torchvision=0.7.0 cudatoolkit=10.1 matplotlib tensorboard scipy opencv -c pytorch
```

## Demos
Expand Down Expand Up @@ -70,6 +69,21 @@ We used the following training schedule in our paper (2 GPUs). Training logs wil
If you have a RTX GPU, training can be accelerated using mixed precision. You can expect similiar results in this setting (1 GPU)
```Shell
./train_mixed.sh
```

## Inference using a pretrained model

```python
import argparse

import raft
from raft.core.raft import RAFT

import torch




Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Whoops. I forgot to finish this bit, and modified demo.py instead.

```

## (Optional) Efficent Implementation
Expand Down
11 changes: 11 additions & 0 deletions conda_env.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
channels:
- pytorch
- defaults
dependencies:
- cudatoolkit=10.1
- matplotlib
- opencv
- pytorch=1.6.0
- scipy
- tensorboard
- torchvision=0.7.0
Empty file removed core/__init__.py
Empty file.
Empty file removed core/utils/__init__.py
Empty file.
Empty file modified demo-frames/frame_0016.png
100755 → 100644
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Empty file modified demo-frames/frame_0017.png
100755 → 100644
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Empty file modified demo-frames/frame_0018.png
100755 → 100644
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Empty file modified demo-frames/frame_0019.png
100755 → 100644
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Empty file modified demo-frames/frame_0020.png
100755 → 100644
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Empty file modified demo-frames/frame_0021.png
100755 → 100644
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Empty file modified demo-frames/frame_0022.png
100755 → 100644
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Empty file modified demo-frames/frame_0023.png
100755 → 100644
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Empty file modified demo-frames/frame_0024.png
100755 → 100644
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Empty file modified demo-frames/frame_0025.png
100755 → 100644
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
75 changes: 0 additions & 75 deletions demo.py

This file was deleted.

Empty file modified download_models.sh
100755 → 100644
Empty file.
12 changes: 12 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
[build-system]
requires = [
"setuptools>=42",
# TODO : coordinate versions with PYPI; until then use conda
# "matplotlib",
# "opencv-python",
# "torch==1.6.0",
# "scipy",
# "tensorboard",
# "torchvision==0.7.0",
]
build-backend = "setuptools.build_meta"
24 changes: 24 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
[metadata]
name = raft
version = 0.0.1
author = Zach Teed
author_email = zachteed@gmail.com
description = RAFT: Recurrent All Pairs Field Transforms for Optical Flow
long_description = file: README.md
long_description_content_type = text/markdown
url = https://github.com/princeton-vl/RAFT
project_urls =
Bug Tracker = https://github.com/princeton-vl/RAFT/issues
classifiers =
Programming Language :: Python :: 3
License :: OSI Approved :: BSD License
Operating System :: OS Independent

[options]
package_dir =
= src
packages = find:
python_requires = >=3.6

[options.packages.find]
where = src
File renamed without changes.
File renamed without changes.
1 change: 1 addition & 0 deletions src/raft/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from . import core
1 change: 1 addition & 0 deletions src/raft/core/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from . import raft, update, extractor, datasets, corr, utils
2 changes: 1 addition & 1 deletion core/corr.py → src/raft/core/corr.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import torch
import torch.nn.functional as F
from utils.utils import bilinear_sampler, coords_grid
from .utils.utils import bilinear_sampler, coords_grid

try:
import alt_cuda_corr
Expand Down
6 changes: 3 additions & 3 deletions core/datasets.py → src/raft/core/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
from glob import glob
import os.path as osp

from utils import frame_utils
from utils.augmentor import FlowAugmentor, SparseFlowAugmentor
from .utils import frame_utils
from .utils.augmentor import FlowAugmentor, SparseFlowAugmentor


class FlowDataset(data.Dataset):
Expand Down Expand Up @@ -126,7 +126,7 @@ def __init__(self, aug_params=None, split='train', root='datasets/FlyingChairs_r
flows = sorted(glob(osp.join(root, '*.flo')))
assert (len(images)//2 == len(flows))

split_list = np.loadtxt('chairs_split.txt', dtype=np.int32)
split_list = np.loadtxt(osp.join(osp.dirname(__file__), '..', 'data', 'chairs_split.txt'), dtype=np.int32)
for i in range(len(flows)):
xid = split_list[i]
if (split=='training' and xid==1) or (split=='validation' and xid==2):
Expand Down
File renamed without changes.
8 changes: 4 additions & 4 deletions core/raft.py → src/raft/core/raft.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
import torch.nn as nn
import torch.nn.functional as F

from update import BasicUpdateBlock, SmallUpdateBlock
from extractor import BasicEncoder, SmallEncoder
from corr import CorrBlock, AlternateCorrBlock
from utils.utils import bilinear_sampler, coords_grid, upflow8
from .update import BasicUpdateBlock, SmallUpdateBlock
from .extractor import BasicEncoder, SmallEncoder
from .corr import CorrBlock, AlternateCorrBlock
from .utils.utils import bilinear_sampler, coords_grid, upflow8

try:
autocast = torch.cuda.amp.autocast
Expand Down
File renamed without changes.
1 change: 1 addition & 0 deletions src/raft/core/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from . import augmentor, flow_viz, frame_utils, utils
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
63 changes: 63 additions & 0 deletions src/raft/demo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import argparse
import os
import cv2
import glob
import numpy as np
import torch
from PIL import Image

from .core.utils import flow_viz
from . import inference



def viz(img, flo):
img = img[0].permute(1,2,0).cpu().numpy()
flo = flo[0].permute(1,2,0).cpu().numpy()

# map flow to rgb image
flo = flow_viz.flow_to_image(flo)
img_flo = np.concatenate([img, flo], axis=0)

# import matplotlib.pyplot as plt
# plt.imshow(img_flo / 255.0)
# plt.show()

cv2.imshow('image', img_flo[:, :, [2,1,0]]/255.0)
cv2.waitKey()


def demo(args):
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model = inference.load_model(
args,
device,
args.model)

def log(x):
print(x)
return x

stream = (
np.array(Image.open(log(impath))).astype(np.uint8)
for impath
in sorted(
glob.glob(os.path.join(args.path, '*.png')) +
glob.glob(os.path.join(args.path, '*.jpg'))
)
)

for image1, image2, flow_low, flow_up in inference.process_stream(stream, model, device, iters=20):
viz(image1, flow_up)


if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--model', required=True, help="restore checkpoint")
parser.add_argument('--path', required=True, help="dataset for evaluation")
parser.add_argument('--small', action='store_true', help='use small model')
parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')
parser.add_argument('--alternate_corr', action='store_true', help='use efficent correlation implementation')
args = parser.parse_args()

demo(args)
11 changes: 5 additions & 6 deletions evaluate.py → src/raft/evaluate.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import sys
sys.path.append('core')

from PIL import Image
import argparse
Expand All @@ -10,12 +9,12 @@
import torch.nn.functional as F
import matplotlib.pyplot as plt

import datasets
from utils import flow_viz
from utils import frame_utils
from . import datasets
from .core.utils import flow_viz
from .core.utils import frame_utils

from raft import RAFT
from utils.utils import InputPadder, forward_interpolate
from .core.raft import RAFT
from .core.utils.utils import InputPadder, forward_interpolate


@torch.no_grad()
Expand Down
63 changes: 63 additions & 0 deletions src/raft/inference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import torch

from .core.utils.utils import InputPadder
from .core.raft import RAFT

def preprocess(image, device):
image = torch.from_numpy(image).permute(2, 0, 1).float()
image = image.unsqueeze(0)
image = image.to(device)
return image


def process_stream(stream, model, device, iters: int = 20):
"""
Processes an image stream and generates tuples of (image1, image2, flow_low, flow_up)
"""
it = iter(stream)
image1 = next(it)
image1 = preprocess(image1, device)

model.eval()
with torch.no_grad():
for image2 in it:
# preprocessing
image2 = preprocess(image2, device)

# pad so shapes match
padder = InputPadder(image1.shape)
image1p, image2p = padder.pad(image1, image2)

# predict the flow
flow_low, flow_up = model(image1p, image2p, iters=iters, test_mode=True)
yield image1p, image2p, flow_low, flow_up

image1 = image2


def cap_stream(cap, n: int = None):
"""
Create an iterable of images from an OpenCV video capture object.
:param n: Maximum number of frames to capture. None means unlimited.
"""
frame_idx = 0
while True:
if n is not None and frame_idx >= n:
break
ret, frame = cap.read()
if not ret:
break
yield frame
frame_idx += 1


def load_model(
raft_args,
device: torch.DeviceObjType,
checkpoint_path: str):
model = RAFT(raft_args)
if device.type == 'cuda':
model = torch.nn.DataParallel(model)
pretrained_weights = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(pretrained_weights)
return model.to(device)
11 changes: 5 additions & 6 deletions train.py → src/raft/train.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import print_function, division
import sys
sys.path.append('core')

import argparse
import os
Expand All @@ -13,14 +12,14 @@
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from torch.utils.data import DataLoader
from raft import RAFT
import evaluate
import datasets

from torch.utils.tensorboard import SummaryWriter

from .core.raft import RAFT
from . import evaluate
from . import datasets


try:
from torch.cuda.amp import GradScaler
except:
Expand Down
Loading