diff --git a/optimum/quanto/nn/qmodule.py b/optimum/quanto/nn/qmodule.py
index 152f9d17..932d4ac4 100644
--- a/optimum/quanto/nn/qmodule.py
+++ b/optimum/quanto/nn/qmodule.py
@@ -217,7 +217,8 @@ def from_module(
         if qmodule is None:
             return None
         # Move the quantized module to the target device, but with empty weights
-        qmodule = qmodule.to_empty(device=module.weight.device)
+        device = torch.device("cpu") if module.weight is None else module.weight.device
+        qmodule = qmodule.to_empty(device=device)
         # Set scales that were initialized to empty values
         qmodule.input_scale = torch.ones_like(qmodule.input_scale)
         qmodule.output_scale = torch.ones_like(qmodule.output_scale)
@@ -226,7 +227,7 @@ def from_module(
         if module.bias is not None:
             qmodule.bias = module.bias
 
-        return qmodule.to(module.weight.device)
+        return qmodule.to(device)
 
     @classmethod
     def qcreate(
diff --git a/test/nn/test_qlayernorm.py b/test/nn/test_qlayernorm.py
index 4a6913cc..371df0da 100644
--- a/test/nn/test_qlayernorm.py
+++ b/test/nn/test_qlayernorm.py
@@ -20,9 +20,9 @@
 from optimum.quanto.nn import QLayerNorm
 
 
-def _test_quantize_layernorm(batch_size, tokens, embeddings, dtype, activations, device):
+def _test_quantize_layernorm(batch_size, tokens, embeddings, affine, dtype, activations, device):
     # Instantiate a normalization layer
-    norm = torch.nn.LayerNorm(embeddings).to(dtype).to(device)
+    norm = torch.nn.LayerNorm(embeddings, elementwise_affine=affine).to(dtype).to(device)
     qnorm = QLayerNorm.from_module(norm, activations=activations)
     qinputs = random_qactivation((batch_size,) + (tokens, embeddings), qtype=activations, dtype=dtype).to(device)
     # Calibrate to avoid clipping and to set the correct dtype
@@ -43,38 +43,42 @@ def _test_quantize_layernorm(batch_size, tokens, embeddings, dtype, activations,
 
 @pytest.mark.parametrize("batch_size", [1, 10])
 @pytest.mark.parametrize("tokens, embeddings", [(32, 32), (10, 32)])
-def test_quantize_layernorm_float16_activations_int8(batch_size, tokens, embeddings, device):
-    _test_quantize_layernorm(batch_size, tokens, embeddings, torch.float16, qint8, device)
+@pytest.mark.parametrize("affine", [True, False], ids=["affine", "non-affine"])
+def test_quantize_layernorm_float16_activations_int8(batch_size, tokens, embeddings, affine, device):
+    _test_quantize_layernorm(batch_size, tokens, embeddings, affine, torch.float16, qint8, device)
 
 
 @pytest.mark.parametrize("batch_size", [1, 10])
 @pytest.mark.parametrize("tokens, embeddings", [(32, 32), (10, 32)])
-def test_quantize_layernorm_float32_activations_int8(batch_size, tokens, embeddings, device):
-    _test_quantize_layernorm(batch_size, tokens, embeddings, torch.float32, qint8, device)
+@pytest.mark.parametrize("affine", [True, False], ids=["affine", "non-affine"])
+def test_quantize_layernorm_float32_activations_int8(batch_size, tokens, embeddings, affine, device):
+    _test_quantize_layernorm(batch_size, tokens, embeddings, affine, torch.float32, qint8, device)
 
 
 @pytest.mark.parametrize("batch_size", [1, 10])
 @pytest.mark.parametrize("tokens, embeddings", [(32, 32), (10, 32)])
+@pytest.mark.parametrize("affine", [True, False], ids=["affine", "non-affine"])
 @pytest.mark.parametrize(
     "activations",
     [qfloat8_e5m2, qfloat8_e4m3fn, qfloat8_e4m3fnuz],
     ids=["a-float8-e5m2", "a-float8-e4m3", "a-float8-e4m3-uz"],
 )
 @pytest.mark.skip_device("mps")
-def test_quantize_layernorm_float16_activations_float8(batch_size, tokens, embeddings, activations, device):
-    _test_quantize_layernorm(batch_size, tokens, embeddings, torch.float16, activations, device)
+def test_quantize_layernorm_float16_activations_float8(batch_size, tokens, embeddings, affine, activations, device):
+    _test_quantize_layernorm(batch_size, tokens, embeddings, affine, torch.float16, activations, device)
 
 
 @pytest.mark.parametrize("batch_size", [1, 10])
 @pytest.mark.parametrize("tokens, embeddings", [(32, 32), (10, 32)])
+@pytest.mark.parametrize("affine", [True, False], ids=["affine", "non-affine"])
 @pytest.mark.parametrize(
     "activations",
     [qfloat8_e5m2, qfloat8_e4m3fn, qfloat8_e4m3fnuz],
     ids=["a-float8-e5m2", "a-float8-e4m3", "a-float8-e4m3-uz"],
 )
 @pytest.mark.skip_device("mps")
-def test_quantize_layernorm_float32_activations_float8(batch_size, tokens, embeddings, activations, device):
-    _test_quantize_layernorm(batch_size, tokens, embeddings, torch.float32, activations, device)
+def test_quantize_layernorm_float32_activations_float8(batch_size, tokens, embeddings, affine, activations, device):
+    _test_quantize_layernorm(batch_size, tokens, embeddings, affine, torch.float32, activations, device)
 
 
 def test_quantize_layernom_no_activation():
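
For context, a minimal sketch (not part of the patch) of the case the qmodule.py change covers: a LayerNorm built with elementwise_affine=False has weight set to None, so the previous module.weight.device lookup failed; with the fallback, the quantized module is materialized on CPU instead. The call shapes mirror the test file above; the concrete size is illustrative.

import torch

from optimum.quanto import qint8
from optimum.quanto.nn import QLayerNorm

# A non-affine LayerNorm has no learnable parameters: norm.weight is None.
norm = torch.nn.LayerNorm(32, elementwise_affine=False)

# Before this change, from_module raised AttributeError on module.weight.device
# for such modules; with the fallback it quantizes the layer on CPU.
qnorm = QLayerNorm.from_module(norm, activations=qint8)
assert qnorm is not None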