
Fixed Sometimes the dtype of the model is incorrect #301

Closed
wants to merge 7 commits into from

Conversation


@balala8 balala8 commented Sep 1, 2024

  1. Fixed an issue where the dtype of the model is sometimes incorrect. This bug often occurs when loading weights from a local file.
  2. Fixed the issue that QuantizedTransformersModel does not have a __call__ method, which caused an error when using the T5Encoder model in Flux.

What does this PR do?

  1. Fixes an issue where the dtype of the model is sometimes incorrect.
  2. Fixes the issue that QuantizedTransformersModel does not have a __call__ method, which caused an error when using the T5Encoder model in quantized Flux (a minimal sketch of the missing delegation follows below).
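For illustration only, here is a minimal sketch of the kind of __call__ delegation the fix adds. QuantizedModelWrapper is a stand-in for QuantizedTransformersModel, its constructor is hypothetical, and the _wrapped attribute mirrors the forward() lines quoted in the review threads below:

class QuantizedModelWrapper:
    # Hypothetical stand-in for QuantizedTransformersModel; the real class
    # builds its wrapped model differently.
    def __init__(self, model):
        self._wrapped = model

    def forward(self, *args, **kwargs):
        return self._wrapped.forward(*args, **kwargs)

    def __call__(self, *args, **kwargs):
        # Without __call__, code that invokes the wrapper directly (as the
        # Flux pipeline does with its text encoder) raises
        # TypeError: 'QuantizedModelWrapper' object is not callable.
        return self.forward(*args, **kwargs)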

Before submitting

  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you run all tests locally and make sure they pass?
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@balala8 balala8 requested a review from dacorvo as a code owner September 1, 2024 16:36
optimum/quanto/models/diffusers_models.py (outdated)
@@ -133,6 +137,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike],
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
with init_empty_weights():
model = cls.auto_class.from_config(config)
dtype = kwargs.get("torch_dtype", torch.float16)
Collaborator

same thing here.

return self._wrapped.forward(*args, **kwargs)

def forward(self, *args, **kwargs):
return self.model.forward(*args, **kwargs)
Collaborator

This reverts the change I just made to fix a bug ...

Author

Sorry, I will change it.

change default dtype from float16 to float32
change forward method

@balala8 balala8 left a comment


  1. change default dtype from float16 to float32


Comment on lines +160 to +161
dtype = kwargs.get("torch_dtype", torch.float32)
model = model.to(dtype=dtype)
Collaborator

This is incorrect, as it will force the model to torch.float32 if torch_dtype is not specified.
I think the correct fix would be:

if "torch_dtype" in kwargs:
    model = model.to(kwargs["torch_dtype"])
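
For context, a hedged sketch of how that conditional cast would sit next to the lines from the diff hunk quoted above; everything except the torch_dtype check is copied from that hunk, not from the full source file:

config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
with init_empty_weights():
    model = cls.auto_class.from_config(config)
# Cast only when the caller explicitly passed torch_dtype, so models loaded
# without it keep whatever dtype from_config produced.
if "torch_dtype" in kwargs:
    model = model.to(kwargs["torch_dtype"])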

Author

I think this parameter must be provided. I quantized the Flux transformer model and the T5 encoder model. After saving them locally and reloading, the bias of a convolutional layer in the transformer model was always in fp32, and the entire T5 encoder model was in fp32, so the acceleration advantage was lost even though the code ran without error. Forcing the model to fp16 avoids this problem.
Here is the code I am using.
Quantize and save:

import torch
from diffusers import FluxTransformer2DModel
from transformers import T5EncoderModel
from optimum.quanto import qfloat8, QuantizedDiffusersModel, QuantizedTransformersModel

class QuantizedFluxTransformer2DModel(QuantizedDiffusersModel):
    base_class = FluxTransformer2DModel

class QuantizedT5EncoderModelForCausalLM(QuantizedTransformersModel):
    auto_class = T5EncoderModel
    auto_class.from_config = auto_class._from_config

if __name__ == "__main__":
    bfl_repo = "black-forest-labs/FLUX.1-dev"
    dtype = torch.float16

    # Quantize the Flux transformer to float8 weights and save it locally.
    transformer = FluxTransformer2DModel.from_pretrained(bfl_repo, subfolder="transformer", torch_dtype=dtype)
    q_transformer = QuantizedFluxTransformer2DModel.quantize(transformer, weights=qfloat8)
    q_transformer.save_pretrained("flux_transforemer_fp8_quanto")

    # Quantize the T5 text encoder the same way and save it locally.
    t5encoder = T5EncoderModel.from_pretrained(bfl_repo, subfolder="text_encoder_2", torch_dtype=dtype)
    t5encoder = QuantizedT5EncoderModelForCausalLM.quantize(t5encoder, weights=qfloat8)
    t5encoder.save_pretrained("flux_T5Encoder_fp8_quanto")

Here is the load and inference code:

import torch
from diffusers import FluxTransformer2DModel, FluxPipeline
from transformers import T5EncoderModel, CLIPTextModel
from optimum.quanto import freeze, qfloat8, quantize, QuantizedTransformersModel, QuantizedDiffusersModel
from huggingface_hub import login

bfl_repo = "black-forest-labs/FLUX.1-dev"
dtype = torch.float16

class QuantizedFluxTransformer2DModel(QuantizedDiffusersModel):
    base_class = FluxTransformer2DModel

class QuantizedT5EncoderModelForCausalLM(QuantizedTransformersModel):
    auto_class = T5EncoderModel
    auto_class.from_config = auto_class._from_config

# Reload the quantized modules that were saved locally above.
transformer = QuantizedFluxTransformer2DModel.from_pretrained("./flux_transforemer_fp8_quanto", torch_dtype=dtype)
transformer.to(device="cuda")
print("transformer.dtype:", transformer.dtype)
text_encoder_2 = QuantizedT5EncoderModelForCausalLM.from_pretrained("./flux_T5Encoder_fp8_quanto", torch_dtype=dtype)
text_encoder_2.to(device="cuda")
print("text_encoder_2.dtype:", text_encoder_2.dtype)

# Build the pipeline without the two quantized components, then attach them.
pipe = FluxPipeline.from_pretrained(bfl_repo,
                                    transformer=None,
                                    text_encoder_2=None,
                                    torch_dtype=dtype)
pipe.transformer = transformer
pipe.text_encoder_2 = text_encoder_2
pipe = pipe.to(device="cuda")

pipe.enable_model_cpu_offload()

prompt = "cookie monster, yarn art style"
image = pipe(
    prompt,
    guidance_scale=3.5,
    output_type="pil",
    num_inference_steps=20,
    generator=torch.Generator("cpu").manual_seed(0)
).images[0]
image.save("flux-fp8-dev.png")
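
As a side note, a hypothetical check (not part of the original report) that makes the symptom visible is to count parameter dtypes after reloading; the transformer._wrapped access assumes the wrapper stores the underlying diffusers module in _wrapped, as the earlier diff suggests:

import collections

import torch

def dtype_histogram(module: torch.nn.Module) -> dict:
    # A stray torch.float32 entry here reveals the bug described above,
    # e.g. a convolution bias left in fp32 after reloading.
    return dict(collections.Counter(p.dtype for p in module.parameters()))

print(dtype_histogram(transformer._wrapped))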


github-actions bot commented Oct 4, 2024

This PR is stale because it has been open 15 days with no activity. Remove stale label or comment or this will be closed in 5 days.

@github-actions github-actions bot added the Stale label Oct 4, 2024

This PR was closed because it has been stalled for 5 days with no activity.

@github-actions github-actions bot closed this Oct 10, 2024