Skip to content

Commit

Permalink
Minor Fix
Browse files Browse the repository at this point in the history
  • Loading branch information
XYZ-99 committed Apr 1, 2023
1 parent 06f9e74 commit bcc8b19
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 8 deletions.
26 changes: 23 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ Coming soon.

## Installation

Our code is tested on Ubuntu **TODO**.
Our code is tested on (Ubuntu xx.xx **TODO**).

* Clone this repository:
```commandline
Expand Down Expand Up @@ -39,8 +39,29 @@ pip install -r requirements.txt

## Training

**TODO**.
### GraspIPDF

```commandline
python ./network/train.py --config-name ipdf_config \
--exp-dir ./ipdf_train
```

### GraspGlow

```commandline
python ./network/train.py --config-name glow_config \
--exp-dir ./glow_train
```

### ContactNet

```commandline
python ./network/train.py TODO
```

### Policy

TODO

## Evaluation

Expand All @@ -63,4 +84,3 @@ pip install -r requirements.txt

* [PointNet++](https://github.com/rusty1s/pytorch_geometric)
* [Implicit PDF](https://github.com/google-research/google-research/tree/master/implicit_pdf)

8 changes: 4 additions & 4 deletions network/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def main(cfg):

""" Logging """
log_dir = cfg["exp_dir"]
os.makedirs(log_dir, exist_ok=True)

logger = logging.getLogger("TrainModel")
logger.setLevel(logging.INFO)
Expand All @@ -61,8 +62,7 @@ def main(cfg):
test_loader = get_dex_dataloader(cfg, "test")

""" Trainer """
input_size = len(train_loader) #train_loader.dataset[0]["features"].shape[0]
trainer = Trainer(input_size, cfg, logger)
trainer = Trainer(cfg, logger)
start_epoch = trainer.resume()

""" Test """
Expand Down Expand Up @@ -91,7 +91,7 @@ def test_all(dataloader, mode, epoch):
if trainer.iteration % cfg["freq"]["plot"] == 0:
cnt = train_loss.pop("cnt")
log_loss_summary(train_loss, cnt,
lambda x, y: logger.info(f"Train {x} is {y}"))
lambda x, y: logger.info(f"Train {x} is {y}"))
log_tensorboard(writer, "train", train_loss, cnt, epoch)

trainer.step_epoch()
Expand All @@ -114,7 +114,7 @@ def test_all(dataloader, mode, epoch):

def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--config-name", type=str, default="ipdf_config.yaml")
parser.add_argument("--config-name", type=str, default="ipdf_config")
parser.add_argument("--exp-dir", type=str, help="E.g., './ipdf_train'.")
return parser.parse_args()

Expand Down
2 changes: 1 addition & 1 deletion network/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def get_last_model(dirname, key=""):


class Trainer(nn.Module):
def __init__(self, input_size, cfg, logger):
def __init__(self, cfg, logger):
super(Trainer, self).__init__()

self.cfg = cfg
Expand Down

0 comments on commit bcc8b19

Please sign in to comment.