Commit

update readme and pip package
jiaweizzhao committed Mar 15, 2024
1 parent a6bc165 commit 33f2c4f
Showing 3 changed files with 72 additions and 13 deletions.
56 changes: 56 additions & 0 deletions README.md
@@ -9,24 +9,80 @@ As a gradient projection method, GaLore is independent of the choice of optimize
<img src="imgs/galore_code_box.png" alt="Image 2" style="width: 550px; margin: 0 auto;">
</div>

## News
Thanks, everyone, for your interest in GaLore!

**We are working on the official release of GaLore.** In the meantime, please feel free to try the pre-release version and share your feedback with us. The current pre-release (e.g., the GaLore optimizers) should already provide a decent memory reduction and an accurate simulation of the GaLore algorithm.

The official release of GaLore will include:

1. Per-layer weight updates for multi-GPU training (DDP and FSDP) (working with [PyTorch](https://pytorch.org/)).
2. Memory-efficient low-rank gradient accumulation (working with [PyTorch](https://pytorch.org/)).
3. Optimized `GaLoreAdamW8bit` (working with [bitsandbytes](https://github.com/TimDettmers/bitsandbytes)).

We would like to express our gratitude to the community members who have been actively working on integrating GaLore into different platforms, including [HuggingFace](https://github.com/huggingface/transformers/pull/29588), [LLaMA-Factory](https://github.com/hiyouga/LLaMA-Factory), and [Axolotl](https://github.com/OpenAccess-AI-Collective/axolotl/pull/1370). Join our Slack workspace [GaLore-Social](https://join.slack.com/t/galore-social/shared_invite/zt-2ev152px0-DguuQ5WRTLQjtq2C88HBvQ) to engage in discussions with us.

## Discussion [(GaLore-Social)](https://join.slack.com/t/galore-social/shared_invite/zt-2ev152px0-DguuQ5WRTLQjtq2C88HBvQ)

We welcome any discussions, questions, and feedback on GaLore. Please join our Slack workspace [GaLore-Social](https://join.slack.com/t/galore-social/shared_invite/zt-2ev152px0-DguuQ5WRTLQjtq2C88HBvQ) to discuss with us and the community.


## Installation

### Install GaLore optimizer
Install from pip:
```bash
pip install galore-torch
```

or if you want to install from source:

```bash
git clone git@github.com:jiaweizzhao/GaLore.git
cd GaLore
pip install -e .
```

### Install experiment dependencies

```bash
pip install -r exp_requirements.txt
```

## Usage

### Save optimizer memory using GaLore optimizers

```python
from galore_torch import GaLoreAdamW, GaLoreAdamW8bit, GaLoreAdafactor
# define param groups as galore_params and non_galore_params
param_groups = [{'params': non_galore_params},
                {'params': galore_params, 'rank': 128, 'update_proj_gap': 200, 'scale': 0.25, 'proj_type': 'std'}]
optimizer = GaLoreAdamW(param_groups, lr=0.01)
```
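
One way to construct the two groups is to collect the 2-D weight matrices of the linear layers you want GaLore to project and treat everything else as regular parameters. The sketch below is a minimal example under that assumption; the module-name filter (`target_modules_list`) is illustrative, not a fixed API, and should be adapted to your model.

```python
import torch.nn as nn

# Collect the weights of the linear layers GaLore should project
# ("attn" / "mlp" are illustrative name filters for this sketch).
galore_params = []
target_modules_list = ["attn", "mlp"]
for module_name, module in model.named_modules():
    if not isinstance(module, nn.Linear):
        continue
    if not any(target in module_name for target in target_modules_list):
        continue
    galore_params.append(module.weight)

# Everything else is optimized as usual.
id_galore_params = {id(p) for p in galore_params}
non_galore_params = [p for p in model.parameters() if id(p) not in id_galore_params]
```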
### Save weight gradient memory using per-layer weight updates

We use `register_post_accumulate_grad_hook` provided by [PyTorch](https://pytorch.org/tutorials/intermediate/optimizer_step_in_backward_tutorial.html) to enable per-layer weight updates. An example is shown below:

```python
# define an optimizer for each trainable parameter p and store it in optimizer_dict
optimizer_dict = {}
for p in model.parameters():
    if p.requires_grad:
        optimizer_dict[p] = GaLoreAdamW([{'params': p, 'rank': 128, 'update_proj_gap': 200, 'scale': 0.25, 'proj_type': 'std'}], lr=0.01)

# define a hook function that updates the parameter p during the backward pass
def optimizer_hook(p):
    if p.grad is None:
        return
    optimizer_dict[p].step()
    optimizer_dict[p].zero_grad()

# register the hook onto every trainable parameter
for p in model.parameters():
    if p.requires_grad:
        p.register_post_accumulate_grad_hook(optimizer_hook)
```
More details can be found in [torchrun_main.py](https://github.com/jiaweizzhao/GaLore/blob/a6bc1650984b1c090a4e108d7c0e3109ee7ad844/torchrun_main.py#L334).
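
Note that with this setup the training loop no longer calls `optimizer.step()` or `optimizer.zero_grad()` itself; the registered hooks update each layer as soon as its gradient is accumulated during the backward pass. A minimal loop sketch, assuming a HuggingFace-style `model` that returns an object with a `.loss` attribute and a standard `dataloader`, looks like:

```python
# Minimal training-loop sketch: the hooks registered above perform the per-layer
# updates inside backward(), so no explicit optimizer.step() is needed here.
for batch in dataloader:
    loss = model(**batch).loss
    loss.backward()
```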

## Benchmark 1: Pre-Training LLaMA on the C4 Dataset
`torchrun_main.py` is the main script for training LLaMA models on C4 with GaLore. Our benchmark scripts for models of various sizes are in the `scripts/benchmark_c4` folder.
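
For illustration, a single-GPU run looks roughly like the sketch below; the flag names and values are assumptions based on the benchmark scripts, so please refer to `scripts/benchmark_c4` for the exact commands.

```bash
# Illustrative example: pre-train a LLaMA-60M model on C4 with GaLore-AdamW on one GPU
# (flags and values are assumptions; see scripts/benchmark_c4 for the exact settings)
torchrun --standalone --nproc_per_node 1 torchrun_main.py \
    --model_config configs/llama_60m.json \
    --optimizer galore_adamw \
    --lr 0.01 \
    --rank 128 \
    --update_proj_gap 200 \
    --galore_scale 0.25 \
    --batch_size 256 \
    --total_batch_size 512 \
    --num_training_steps 10000 \
    --warmup_steps 1000 \
    --weight_decay 0 \
    --dtype bfloat16
```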
14 changes: 14 additions & 0 deletions exp_requirements.txt
@@ -0,0 +1,14 @@
torch
transformers==4.31.0
tokenizers
datasets
peft
wandb
loguru
nvitop
lion-pytorch
matplotlib
bitsandbytes
scipy
scikit-learn
evaluate
15 changes: 2 additions & 13 deletions requirements.txt
@@ -1,14 +1,3 @@
 torch
-transformers==4.31.0
-tokenizers
-datasets
-peft
-wandb
-loguru
-nvitop
-lion-pytorch
-matplotlib
-bitsandbytes
-scipy
-scikit-learn
-evaluate
+transformers
+bitsandbytes
