Third-party benchmark #6
Thank you for sharing! Have you checked accuracy benchmarks too?
@samuelazran Nope, but the loss curve is pretty good for me.
Hi @hiyouga, I am trying out GaLore with this repo. However, I am experiencing very low throughput on an A6000. How did you manage to get >1 it/s? In addition, if I understand correctly, GaLore reduces O(N) operations (the element-wise scaling) but adds O(N^3) operations (SVD and projections) on top of Adam-8bit, so how is it faster?
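For reference, here is a minimal sketch of the per-matrix update GaLore performs (paraphrased from the paper, not taken from the galore-torch implementation; all names and defaults are illustrative). It shows where both the extra SVD/projection work and the memory saving come from:
'''
import torch

def galore_adam_step(W, G, state, lr=1e-3, rank=128, update_proj_gap=200,
                     scale=0.25, beta1=0.9, beta2=0.999, eps=1e-8):
    # One GaLore-style update for a single m x n weight W with gradient G.
    # Call under torch.no_grad(); `state` is a per-parameter dict.
    state["step"] = state.get("step", 0) + 1
    if "P" not in state or (state["step"] - 1) % update_proj_gap == 0:
        # Periodic SVD of the gradient gives the low-rank projector (the costly part).
        U, _, _ = torch.linalg.svd(G, full_matrices=False)
        state["P"] = U[:, :rank]                          # m x rank
    if "m" not in state:
        # Adam statistics live only in the projected rank x n space -> memory saving.
        state["m"] = torch.zeros(rank, G.shape[1], device=G.device, dtype=G.dtype)
        state["v"] = torch.zeros_like(state["m"])
    P = state["P"]
    R = P.T @ G                                           # project gradient: rank x n
    state["m"].mul_(beta1).add_(R, alpha=1 - beta1)
    state["v"].mul_(beta2).addcmul_(R, R, value=1 - beta2)
    m_hat = state["m"] / (1 - beta1 ** state["step"])
    v_hat = state["v"] / (1 - beta2 ** state["step"])
    N = m_hat / (v_hat.sqrt() + eps)                      # Adam step in low-rank space
    W.add_(P @ N, alpha=-lr * scale)                      # project back and apply
'''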
@yongchanghao Sorry, we might have missed some experimental details. We used …
@hiyouga I'm also confused about why GaLore can improve throughput without increasing batch_size. Actually, the paper mentions "which …
@pkumc The previous results were indeed somewhat unfair. We have now adjusted the experimental setup and updated the results. When the rank is small (<128), GaLore still has better throughput. I guess it may be because GaLore has fewer FLOPs in training. Regarding the data reported in the paper, we have discussed it with the author; it may be due to different hardware with varying GEMM performance.
@hiyouga Thanks for the update. I feel the current data makes more sense. For future readers' reference, my preliminary experience aligns well with the data reported in #3 (comment)
I believe you need to do GaLore layer by layer in order to save memory, as in Line 334 in a6bc165.
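For context, here is a minimal sketch of that layer-wise pattern, assuming PyTorch >= 2.1 (for register_post_accumulate_grad_hook) and the galore-torch API as shown in its README (GaLoreAdamW with per-group rank/update_proj_gap/scale/proj_type keys); it mirrors the idea rather than the exact repository code:
'''
import torch
from galore_torch import GaLoreAdamW  # assumed API, as shown in the galore-torch README

def attach_layerwise_galore(model, rank=128, update_proj_gap=200, scale=0.25, lr=1e-5):
    # Give every trainable parameter its own tiny optimizer and step it from a
    # post-accumulate-grad hook, so each layer's gradient is consumed and freed
    # immediately instead of being kept until a global optimizer.step().
    optimizers = {}
    for p in model.parameters():
        if not p.requires_grad:
            continue
        if p.dim() == 2:  # GaLore projects matrix-shaped weights only
            group = {"params": [p], "rank": rank, "update_proj_gap": update_proj_gap,
                     "scale": scale, "proj_type": "std"}
            optimizers[p] = GaLoreAdamW([group], lr=lr)
        else:
            optimizers[p] = torch.optim.AdamW([p], lr=lr)

        def step_hook(param):
            opt = optimizers[param]
            opt.step()
            opt.zero_grad(set_to_none=True)  # drop this layer's gradient right away

        p.register_post_accumulate_grad_hook(step_hook)  # PyTorch >= 2.1
    return optimizers  # keep a reference; no global optimizer.step() is called
'''
In the transformers integration, the same effect is what the *_layerwise optimizer names in the list below provide.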
This is not formal research, but although GaLore reduces the amount of memory used, it is undeniable that GaLore increases training time by roughly a factor of three, which is not friendly to LLM training. This is the test code:
'''
#install
conda create --name test python=3.11
conda activate test
export CUDA_HOME=xxxxxxx
export LD_LIBRARY_PATH=$CUDA_HOME"/lib64:$LD_LIBRARY_PATH"
export PATH=$CUDA_HOME"/bin:$PATH"
pip install -U transformers trl datasets
conda install pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c nvidia
pip install galore-torch
# HF-supported optimizers:
['adamw_hf', 'adamw_torch', 'adamw_torch_fused', 'adamw_torch_xla', 'adamw_torch_npu_fused', 'adamw_apex_fused', 'adafactor', 'adamw_anyprecision', 'sgd', 'adagrad', 'adamw_bnb_8bit', 'adamw_8bit', 'lion_8bit', 'lion_32bit',
'paged_adamw_32bit', 'paged_adamw_8bit', 'paged_lion_32bit', 'paged_lion_8bit', 'rmsprop', 'rmsprop_bnb', 'rmsprop_bnb_8bit', 'rmsprop_bnb_32bit',
'galore_adamw', 'galore_adamw_8bit', 'galore_adafactor',
'galore_adamw_layerwise', 'galore_adamw_8bit_layerwise', 'galore_adafactor_layerwise']
'''
'''
import torch
import datasets
from transformers import TrainingArguments, AutoConfig, AutoTokenizer, AutoModelForCausalLM
import trl, time

train_dataset = datasets.load_dataset('imdb', split='train')

args = TrainingArguments(
    output_dir="./test-galore",
    max_steps=100,
    per_device_train_batch_size=2,
    optim="adamw_hf",  # baseline; swap in one of the galore_* optimizers above to benchmark GaLore
    optim_target_modules=["attn", "mlp"],  # only used by the galore_* optimizers
)

model_id = "Qwen/Qwen1.5-0.5B"
#model_id = "Qwen/Qwen1.5-4B"
#model_id = "Qwen/Qwen1.5-7B"
#model_id = "mistralai/Mistral-7B-v0.1"

config = AutoConfig.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_config(config).to(0)  # randomly initialized weights; fine for a speed/memory test

trainer = trl.SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    args=args,
    train_dataset=train_dataset,
    dataset_text_field='text',
    max_seq_length=512,
)

start_time = time.time()
trainer.train()
train_time = time.time() - start_time

print(f"=====================================================")
print(f"Time Used: {train_time:.2f} s")
print(f"memory_allocated: {torch.cuda.memory_allocated()/1024.0/1024.0:.2f} MB")
print(f"max_memory_allocated: {torch.cuda.max_memory_allocated()/1024.0/1024.0:.2f} MB")
print(f"memory_reserved: {torch.cuda.memory_reserved()/1024.0/1024.0:.2f} MB")
print(f"max_memory_reserved: {torch.cuda.max_memory_reserved()/1024.0/1024.0:.2f} MB")
print(f"free memory: {torch.cuda.mem_get_info()[0]/1024.0/1024.0:.2f} MB")
print(f"=====================================================")
'''
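For the GaLore runs, presumably only the TrainingArguments change, with the optimizer swapped for one of the galore_* values listed above; a sketch of that variant (GaLore hyperparameters are illustrative and passed via optim_args, as described in the transformers docs):
'''
args = TrainingArguments(
    output_dir="./test-galore",
    max_steps=100,
    per_device_train_batch_size=2,
    optim="galore_adamw",                    # or galore_adamw_8bit / galore_adamw_layerwise, etc.
    optim_target_modules=["attn", "mlp"],
    optim_args="rank=128, update_proj_gap=200, scale=0.25",  # optional; illustrative values
)
'''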
Thanks for providing your results @WangRongsheng. We are working on efficiency optimization and you can expect a big throughput boost in the next version. For the train_loss, did you tune the lr for GaLore?
I will do it. |
Hello, thank you very much for such excellent work. We have conducted some experiments using LLaMA-Factory, and the results indicate that GaLore can significantly reduce memory usage during full-parameter fine-tuning. We used the 8-bit AdamW optimizer and pure bfloat16 training with gradient checkpointing. GaLore requires only 18GB of VRAM to train a Llama-2 7B model, while the standard 8-bit AdamW optimizer requires at least 40GB of VRAM. We provide reproducible scripts for SFT training here: https://github.com/hiyouga/LLaMA-Factory/blob/main/examples/extras/galore/galore_adamw_8bit_bf16.sh (a sketch of comparable settings follows after the notes below).
* We omitted the time of computing the SVD for GaLore every update_proj_gap steps; it costs around 10 minutes for a 7B model.
Experiment results last updated: Mar 9th.
todo: add loss convergence results.
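For readers who cannot open the linked script, here is a rough sketch of the kind of TrainingArguments described above (8-bit layer-wise GaLore AdamW, bf16, gradient checkpointing), assuming a transformers version with GaLore support; the exact values in the LLaMA-Factory example may differ:
'''
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="./llama2-7b-galore-sft",
    per_device_train_batch_size=1,
    learning_rate=1e-5,
    bf16=True,                               # bf16 autocast; "pure" bf16 additionally loads the model in torch.bfloat16
    gradient_checkpointing=True,             # trade recomputation for activation memory
    optim="galore_adamw_8bit_layerwise",     # 8-bit AdamW states + layer-wise GaLore updates
    optim_target_modules=["attn", "mlp"],    # module-name patterns; adjust to the target model
    optim_args="rank=128, update_proj_gap=200, scale=0.25",  # illustrative GaLore hyperparameters
)
'''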