Commit 33f2c4f (1 parent: a6bc165) — showing 3 changed files with 72 additions and 13 deletions.
@@ -9,24 +9,80 @@ As a gradient projection method, GaLore is independent of the choice of optimize

<img src="imgs/galore_code_box.png" alt="Image 2" style="width: 550px; margin: 0 auto;">
</div>

## News
Thanks everyone for the interest in GaLore!

**We are working on the official release of GaLore.** In the meantime, please feel free to try the pre-release version and send us your feedback. The pre-release version (e.g., the GaLore optimizers) already provides a decent memory reduction and an accurate simulation of the GaLore algorithm.

The official release of GaLore will include:

1. Per-layer weight updates for multi-GPU training (DDP and FSDP) (working with [PyTorch](https://pytorch.org/)).
2. Memory-efficient low-rank gradient accumulation (working with [PyTorch](https://pytorch.org/)).
3. Optimized `GaLoreAdamW8bit` (working with [bitsandbytes](https://github.com/TimDettmers/bitsandbytes)).

We would like to express our gratitude to the community members who have been actively integrating GaLore into different platforms, including [HuggingFace](https://github.com/huggingface/transformers/pull/29588), [LLaMA-Factory](https://github.com/hiyouga/LLaMA-Factory), and [Axolotl](https://github.com/OpenAccess-AI-Collective/axolotl/pull/1370). Join our Slack workspace [GaLore-Social](https://join.slack.com/t/galore-social/shared_invite/zt-2ev152px0-DguuQ5WRTLQjtq2C88HBvQ) to engage in discussions with us.

## Discussion [(GaLore-Social)](https://join.slack.com/t/galore-social/shared_invite/zt-2ev152px0-DguuQ5WRTLQjtq2C88HBvQ)

We welcome any discussions, questions, and feedback on GaLore. Please join our Slack workspace [GaLore-Social](https://join.slack.com/t/galore-social/shared_invite/zt-2ev152px0-DguuQ5WRTLQjtq2C88HBvQ) to discuss with us and the community.

## Installation

### Install the GaLore optimizer
Install from pip:
```bash
pip install galore-torch
```

or install from source:

```bash
git clone [email protected]:jiaweizzhao/GaLore.git
cd GaLore
pip install -e .
```

### Install experiment dependencies

```bash
pip install -r exp_requirements.txt
```

## Usage

### Save optimizer memory using GaLore optimizers

```python
from galore_torch import GaLoreAdamW, GaLoreAdamW8bit, GaLoreAdafactor

# split your model's parameters into galore_params (low-rank projected)
# and non_galore_params beforehand
param_groups = [{'params': non_galore_params},
                {'params': galore_params, 'rank': 128, 'update_proj_gap': 200, 'scale': 0.25, 'proj_type': 'std'}]
optimizer = GaLoreAdamW(param_groups, lr=0.01)
```
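How you split parameters into `galore_params` and `non_galore_params` is up to you. A common choice (an assumption here, not mandated by the API) is to project only the 2-D weight matrices and keep biases and norm parameters in the regular group. A minimal sketch using plain PyTorch, with `torch.optim.AdamW` standing in for `GaLoreAdamW` so it runs without `galore_torch` installed:

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 10))

# assumption: only 2-D weight matrices get the low-rank projection;
# 1-D parameters (biases, norms) stay in the regular group
galore_params = [p for p in model.parameters() if p.ndim == 2]
non_galore_params = [p for p in model.parameters() if p.ndim != 2]

param_groups = [
    {'params': non_galore_params},
    {'params': galore_params, 'rank': 128, 'update_proj_gap': 200,
     'scale': 0.25, 'proj_type': 'std'},
]

# AdamW is a stand-in here: PyTorch optimizers keep (and ignore) extra
# per-group keys, so the same param_groups can be passed to GaLoreAdamW
optimizer = torch.optim.AdamW(param_groups, lr=0.01)
```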
### Save weight gradient memory using per-layer weight updates

We use `register_post_accumulate_grad_hook` provided by [PyTorch](https://pytorch.org/tutorials/intermediate/optimizer_step_in_backward_tutorial.html) to enable per-layer weight updates. An example is shown below:

```python
# define an optimizer for each parameter p, and store them in optimizer_dict
optimizer_dict = {}
for p in model.parameters():
    if p.requires_grad:
        optimizer_dict[p] = GaLoreAdamW([{'params': p, 'rank': 128, 'update_proj_gap': 200, 'scale': 0.25, 'proj_type': 'std'}], lr=0.01)

# define a hook function to update the parameter p during the backward pass
def optimizer_hook(p):
    if p.grad is None:
        return
    optimizer_dict[p].step()
    optimizer_dict[p].zero_grad()

# register the hook onto every parameter
for p in model.parameters():
    if p.requires_grad:
        p.register_post_accumulate_grad_hook(optimizer_hook)
```
More details can be found in [torchrun_main.py](https://github.com/jiaweizzhao/GaLore/blob/a6bc1650984b1c090a4e108d7c0e3109ee7ad844/torchrun_main.py#L334).

## Benchmark 1: Pre-Training LLaMA on the C4 dataset
`torchrun_main.py` is the main script for training LLaMA models on C4 with GaLore. Our benchmark scripts for various model sizes are in the `scripts/benchmark_c4` folder.
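For reference, a single-GPU invocation looks roughly like the sketch below. The flag names are an assumption based on the benchmark scripts, not a verified command line; check `scripts/benchmark_c4` for the exact arguments used for each model size.

```shell
# hypothetical example for a small LLaMA config; verify the flags against
# the scripts in scripts/benchmark_c4 before running
torchrun --standalone --nproc_per_node 1 torchrun_main.py \
    --model_config configs/llama_60m.json \
    --optimizer galore_adamw \
    --lr 0.01 \
    --galore_scale 0.25 \
    --rank 128 \
    --update_proj_gap 200 \
    --batch_size 256 \
    --total_batch_size 512 \
    --num_training_steps 10000
```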
@@ -0,0 +1,14 @@
torch
transformers==4.31.0
tokenizers
datasets
peft
wandb
loguru
nvitop
lion-pytorch
matplotlib
bitsandbytes
scipy
scikit-learn
evaluate
@@ -1,14 +1,3 @@
 torch
-transformers==4.31.0
-tokenizers
-datasets
-peft
-wandb
-loguru
-nvitop
-lion-pytorch
-matplotlib
-bitsandbytes
-scipy
-scikit-learn
-evaluate
+transformers
+bitsandbytes