Skip to content

Commit

Permalink
Merge branch 'master' of github.com:jmhb0/o2vae
Browse files Browse the repository at this point in the history
  • Loading branch information
jmhb0 committed Jan 12, 2023
2 parents 811d187 + 74ebda2 commit 895b169
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@ To train an o2-vae model, edit `./run.bash` to point to the right config file, a
```
bash run.bash
```
The example commands in that script are for the demo dataset and configs (mext section). Training these demos on GPUs (nvidia-rtx) with the default configs in `run.bash` takes <1min per training epoch for both demo datasets. Training converges in about 50 epochs.

**Important** check the terminal for the location of the saved models. Something like:
> Logging directory is `wandb/<log_dir>`
Expand All @@ -86,10 +88,10 @@ We provide two demo datasets, [o2-mnist](./data/o2_mnist/README.md) and [MEFS](.
python data/generate_o2mnist.py
bash data/mefs/unzip_mefs.bash
```
They each have a config file `configs/config_o2mnst.py` and `configs/config_mefs.py`. A model can be trained using the script above, OR they can be run in notebooks `notebooks/`
They each have a config file `configs/config_o2mnst.py` and `configs/config_mefs.py`. A model can be trained using the script above, OR they can be run in notebooks `examples/`

### Running in a notebook
Examples notebooks for training models are in `notebooks/`. This is mostly the same code as `run.py` but without any logging.
Examples notebooks for training models are in `examples/`. This is mostly the same code as `run.py` but without any logging.

## <a name="usage2"/> Usage - using representation for analysis
### Recovering trained models
Expand All @@ -102,7 +104,7 @@ import torch
from configs.<my_config> import config
model=run.get_datasets_from_config(config)
fname_model=wandb/<run_name>/files_model.pt`
fname_model=wandb/<run_name>/files_model.pt
saved_model=torch.load(fname_model)
model.load_state_dict(saved_model['state_dict'])
```
Expand Down

0 comments on commit 895b169

Please sign in to comment.