This is the official code release for our ECCV 2022 paper, "Particle Video Revisited: Tracking Through Occlusions Using Point Trajectories". [Paper] [Project Page]
Very soon, this repo will also host the PIPs++ model from our ICCV 2023 paper, "PointOdyssey: A Large-Scale Synthetic Dataset for Long-Term Point Tracking". [Paper] [Project Page]
-
Added new reference model, trained on a slightly harder version of FlyingThings++ and larger batch size.
-
Updated
train.py
andflyingthings.py
with these settings. Get the new reference model (as before) with./get_reference_model.sh
or from HuggingFace. -
New results are better than the paper. Here they are:
BADJA:
bear: 76.1 camel: 91.6 cows: 87.7 dog-agility: 31.0 dog: 45.4 horsejump-high: 60.9 horsejump-low: 58.1 avg: 64.4
CroHD:
vis: 4.57 occ: 7.71
FlyingThings:
pips: ate_vis = 6.03, ate_occ = 19.56 raft: ate_vis = 16.65, ate_occ = 43.22 dino: ate_vis = 42.98, ate_occ = 76.78
The lines below should set up a fresh environment with everything you need:
conda create --name pips
source activate pips
conda install pytorch=1.12.0 torchvision=0.13.0 cudatoolkit=11.3 -c pytorch
conda install pip
pip install -r requirements.txt
If you want to run this code on Apple's silicon, use the following:
conda create --name pips python=3.9
source activate pips
conda install pytorch torchvision torchaudio -c pytorch
conda install pip
pip install -r requirements.txt
At the time of this fix, MPS does not support all operators required for this code to run so you will also need:
export PYTORCH_ENABLE_MPS_FALLBACK=1
To download our reference model, download the weights from Hugging Face.
Alternatively, you can run this:
sh get_reference_model.sh
To run this model on a sample video, run this:
python demo.py
This will run the model on a sequence included in demo_images/
.
For each 8-frame subsequence, the model will return trajs_e
. This is estimated trajectory data for the particles, shaped B,S,N,2
, where S
is the sequence length and N
is the number of particles, and 2
is the x
and y
coordinates. The script will also produce tensorboard logs with visualizations, which go into logs_demo/
, as well as a few gifs in ./*.gif
.
In the tensorboard for logs_demo/
you should be able to find visualizations like this:
To track points across arbitrarily-long videos, run this:
python chain_demo.py
In the tensorboard for logs_chain_demo/
you should be able to find visualizations like this:
This type of tracking is much more challenging, so you can expect to see more failures here. In particular, here we are using our visibility-aware chaining method, so mistakes tend to propagate into the future.
The original video is https://www.youtube.com/watch?v=LaqYt0EZIkQ
. The file demo_images/extract_frames.sh
shows the ffmpeg command we used to export frames from the mp4.
To inspect our PIPs model implementation, the main file to look at is nets/pips.py
To download our exact FlyingThings++ dataset, try this link. If the link doesn't work, contact me for a secondary link, or create the data from the original FlyingThings, as described next.
To create our FlyingThings++ dataset, first download FlyingThings. The data should look like:
../flyingthings/optical_flow/
../flyingthings/object_index/
../flyingthings/frames_cleanpass_webp/
Once you have the flows and masks, you can run python make_trajs.py
. This will put 80G of trajectory data into:
../flyingthings/trajs_ad/
In parallel, you can run python make_occlusions.py
. This will put 537M of occlusion data into:
../flyingthings/occluders_al
This data will be loaded and joined with corresponding rgb by the FlyingThingsDataset
class in flyingthingsdataset.py
, when training and testing.
(The suffixes "ad" and "al" are version counters.)
Once loaded by the dataloader (flyingthingsdataset.py
), the RGB will look like this:
The corresponding trajectories will look like this:
To train a model on the flyingthings++ dataset:
python train.py
First it should print some diagnostic information about the model and data. Then, it should print a message for each training step, indicating the model name, progress, read time, iteration time, and loss.
model_name 1_8_128_I6_3e-4_A_tb89_21:34:46
loading FlyingThingsDataset [...] found 13085 samples in ../flyingthings (dset=TRAIN, subset=all, version=ad)
loading occluders [...] found 7856 occluders in ../flyingthings (dset=TRAIN, subset=all, version=al)
not using augs in val
loading FlyingThingsDataset [...] found 2542 samples in ../flyingthings (dset=TEST, subset=all, version=ad)
loading occluders...found 1631 occluders in ../flyingthings (dset=TEST, subset=all, version=al)
warning: updated load_fails (on this worker): 1/13085...
1_8_128_I6_3e-4_A_tb89_21:34:46; step 000001/100000; rtime 9.79; itime 20.24; loss = 40.30593
1_8_128_I6_3e-4_A_tb89_21:34:46; step 000002/100000; rtime 0.01; itime 0.37; loss = 43.12448
1_8_128_I6_3e-4_A_tb89_21:34:46; step 000003/100000; rtime 0.01; itime 0.36; loss = 36.60324
1_8_128_I6_3e-4_A_tb89_21:34:46; step 000004/100000; rtime 0.01; itime 0.38; loss = 40.91223
1_8_128_I6_3e-4_A_tb89_21:34:46; step 000005/100000; rtime 0.01; itime 0.35; loss = 79.32227
1_8_128_I6_3e-4_A_tb89_21:34:46; step 000006/100000; rtime 0.01; itime 0.53; loss = 22.14469
1_8_128_I6_3e-4_A_tb89_21:34:46; step 000007/100000; rtime 0.04; itime 0.46; loss = 24.75386
[...]
Occasional load_fails
warnings are typically harmless. They indicate when the dataloader fails to get N
trajectories for a given video, which simply causes a retry. If you greatly increase N
(the number of trajectories), or reduce the crop size, you can expect this warning to occur more frequently, since these constraints make it more difficult to find viable samples. As long as your rtime
(read time) is small, then things are basically OK.
To reproduce the result in the paper, you should train with 4 gpus, with horizontal and vertical flips, with a command like this:
python train.py --horz_flip=True --vert_flip=True --device_ids=[0,1,2,3]
We provide evaluation scripts for all of the datasets reported in the paper. By default, these scripts will evaluate a PIPs model, with a checkpoint folder specified by the --init_dir
argument.
You can also try a baseline, with --modeltype='raft'
or --modeltype='dino'
. To do this, you will also want to download a RAFT model. The DINO model should download itself, since torch makes this easy.
The CroHD head tracking data comes from the "Get all data" link on the Head Tracking 21 MOT Challenge page. Downloading and unzipping that should give you the folders HT21 and HT21Labels, which our dataloader relies on. After downloading the data (and potentially editing the dataset_location
in crohddataset.py
), you can evaluate the model in CroHD with: python test_on_crohd.py
To evaluate the model in Flyingthings++, first set up the data as described in the earlier section of this readme, then: python test_on_flt.py
The DAVIS dataset comes from the "TrainVal - Images and Annotations - Full-Resolution" link, on the DAVIS Challenge page. After downloading the data (and potentially editing the data_path
in test_on_davis.py
), you can visualize the model's outputs in DAVIS with: python test_on_davis.py
To evaluate the model in BAJDA, first follow the instructions at the BADJA repo. This will involve downloading DAVIS trainval full-resolution data. After downloading the data (and potentially editing BADJA_PATH
in badjadataset.py
), you can evaluate the model in BADJA with: python test_on_badja.py
If you use this code for your research, please cite:
Particle Video Revisited: Tracking Through Occlusions Using Point Trajectories. Adam W. Harley, Zhaoyuan Fang, Katerina Fragkiadaki. In ECCV 2022.
Bibtex:
@inproceedings{harley2022particle,
title={Particle Video Revisited: Tracking Through Occlusions Using Point Trajectories},
author={Adam W Harley and Zhaoyuan Fang and Katerina Fragkiadaki},
booktitle={ECCV},
year={2022}
}