Fixed: Sometimes the dtype of the model is incorrect #301
Conversation
@@ -133,6 +137,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike],
        config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
        with init_empty_weights():
            model = cls.auto_class.from_config(config)
        dtype = kwargs.get("torch_dtype", torch.float16)
same thing here.
        return self._wrapped.forward(*args, **kwargs)

    def forward(self, *args, **kwargs):
        return self.model.forward(*args, **kwargs)
This reverts the change I just made to fix a bug ...
Sorry, I will change it.
change default dtype from float16 to float32
change forward method
- change default dtype from float16 to float32
        return self._wrapped.forward(*args, **kwargs)

    def forward(self, *args, **kwargs):
        return self.model.forward(*args, **kwargs)
Sorry, I will change it.
        dtype = kwargs.get("torch_dtype", torch.float32)
        model = model.to(dtype=dtype)
This is incorrect, as it will force the model to torch.float32 if torch_dtype is not specified.
I think the correct fix would be:

        if "torch_dtype" in kwargs:
            model = model.to(kwargs["torch_dtype"])
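For context, a minimal sketch of how that conditional cast could slot into the from_pretrained path shown in the hunk above; everything except the last two lines is taken from the diff, and the rest of the method body is abridged:

        config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
        with init_empty_weights():
            model = cls.auto_class.from_config(config)
        # Cast only when the caller explicitly passed torch_dtype; otherwise keep
        # the dtype produced by the config instead of forcing float32 or float16.
        if "torch_dtype" in kwargs:
            model = model.to(kwargs["torch_dtype"])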
I think this parameter must be provided. I quantized the flux transformer model and the t5encoder model, saved them locally, and reloaded them. On reload, the bias of a convolutional layer in the transformer model was always in fp32, and the entire t5encoder model was in fp32, so the expected speed-up did not materialize, even though the code ran without error. Forcing the model to fp16 avoids this problem.
Here is the code I am using. Quantize and save:
import torch
from diffusers import FluxTransformer2DModel
from transformers import T5EncoderModel
from optimum.quanto import qfloat8, QuantizedDiffusersModel, QuantizedTransformersModel

class QuantizedFluxTransformer2DModel(QuantizedDiffusersModel):
    base_class = FluxTransformer2DModel

class QuantizedT5EncoderModelForCausalLM(QuantizedTransformersModel):
    auto_class = T5EncoderModel
    auto_class.from_config = auto_class._from_config

if __name__ == "__main__":
    bfl_repo = "black-forest-labs/FLUX.1-dev"
    dtype = torch.float16
    transformer = FluxTransformer2DModel.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=dtype)
    q_transformer = QuantizedFluxTransformer2DModel.quantize(transformer, weights=qfloat8)
    q_transformer.save_pretrained("flux_transforemer_fp8_quanto")
    t5encoder = T5EncoderModel.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="text_encoder_2", torch_dtype=dtype)
    t5encoder = QuantizedT5EncoderModelForCausalLM.quantize(t5encoder, weights=qfloat8)
    t5encoder.save_pretrained("flux_T5Encoder_fp8_quanto")
Here is the load and inference code:
import torch
from diffusers import FluxTransformer2DModel, FluxPipeline
from transformers import T5EncoderModel, CLIPTextModel
from optimum.quanto import freeze, qfloat8, quantize, QuantizedTransformersModel, QuantizedDiffusersModel
from huggingface_hub import login

bfl_repo = "black-forest-labs/FLUX.1-dev"
dtype = torch.float16

class QuantizedFluxTransformer2DModel(QuantizedDiffusersModel):
    base_class = FluxTransformer2DModel

class QuantizedT5EncoderModelForCausalLM(QuantizedTransformersModel):
    auto_class = T5EncoderModel
    auto_class.from_config = auto_class._from_config

# Reload the quantized checkpoints saved above.
transformer = QuantizedFluxTransformer2DModel.from_pretrained("./flux_transforemer_fp8_quanto", torch_dtype=dtype)
transformer.to(device="cuda")
print("transformer.dtype:", transformer.dtype)

text_encoder_2 = QuantizedT5EncoderModelForCausalLM.from_pretrained("./flux_T5Encoder_fp8_quanto", torch_dtype=dtype)
text_encoder_2.to(device="cuda")
print("text_encoder_2.dtype:", text_encoder_2.dtype)

# Build the pipeline without these two components, then plug in the quantized ones.
pipe = FluxPipeline.from_pretrained(bfl_repo,
                                    transformer=None,
                                    text_encoder_2=None,
                                    torch_dtype=dtype)
pipe.transformer = transformer
pipe.text_encoder_2 = text_encoder_2
pipe = pipe.to(device="cuda")
pipe.enable_model_cpu_offload()

prompt = "cookie monster, yarn art style"
image = pipe(
    prompt,
    guidance_scale=3.5,
    output_type="pil",
    num_inference_steps=20,
    generator=torch.Generator("cpu").manual_seed(0),
).images[0]
image.save("flux-fp8-dev.png")
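Continuing from the load script above, a rough check (not part of the PR, and assuming the quantized wrapper exposes the underlying module as _wrapped, as in the diff discussed earlier) to list which parameters are still in float32 after reloading:

# Hypothetical dtype audit on the reloaded quantized transformer.
for name, param in transformer._wrapped.named_parameters():
    if param.dtype == torch.float32:
        print(name, param.dtype)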
This PR is stale because it has been open 15 days with no activity. Remove stale label or comment or this will be closed in 5 days.
This PR was closed because it has been stalled for 5 days with no activity.
What does this PR do?
Before submitting
Did you read the contributor guideline, Pull Request section?
Was this discussed/approved via a Github issue? Please add a link to it if that's the case.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.