
Commit

[Pallas support] Add quant_blockwisely function, which quantizes tensor according to the BlockSpec of the inputs.

PiperOrigin-RevId: 626918053
lenscloth authored and copybara-github committed Apr 25, 2024
1 parent 6bb79d1 commit 50ca5d3
Showing 3 changed files with 344 additions and 2 deletions.
164 changes: 164 additions & 0 deletions aqt/jax/v2/aqt_pallas.py
@@ -0,0 +1,164 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""AQT for pallas."""

import dataclasses
from typing import Sequence
from aqt.jax.v2 import aqt_tensor
import jax
from jax.experimental import pallas as pl
import jax.numpy as jnp


QTensor = aqt_tensor.QTensor


def _pop_and_append(l, p):
"""Pop l[p] and append at the back."""
if isinstance(l, list):
e = l.pop(p)
l.append(e)
elif isinstance(l, tuple):
e = l[p]
    l = (*l[:p], *l[p+1:], e)
return l


def quant_blockwisely(
x: jax.Array,
n_bits: int,
calibration_axes: Sequence[int],
block_spec: pl.BlockSpec,
) -> tuple[QTensor, QTensor]:
"""Quantize x block-wisely according to block_spec.
x is quantized block-wisely (a.k.a subchannel) on the calibration axes, and
the size of block of each axis is determined by block_spec.block_shape[axis]
Args:
x: input tensor
n_bits: the precision for quantization.
calibration_axes: the calibration axes.
block_spec: Pallas BlockSpec of the input x
Returns:
A tuple of QTensor and block spec of that QTensor.
"""

if n_bits not in [4, 8]:
raise ValueError('n_bits must be either 4 or 8')

# TODO(wppark): use aqt_quantizer.Quantizer instead of code written from
# scratch.
tiled_x_shape = []
for axis, ndim in enumerate(x.shape):
if axis in calibration_axes:
tiled_x_shape += [
ndim // block_spec.block_shape[axis],
block_spec.block_shape[axis],
]
else:
tiled_x_shape += [ndim]

tiled_x = jnp.reshape(x, tiled_x_shape)
tiled_calibration_axes = [
(i + 1) + idx for i, idx in enumerate(calibration_axes)
]

abs_max = jnp.max(
jnp.abs(tiled_x), axis=tiled_calibration_axes, keepdims=True
)
tiled_scale = abs_max / (2 ** (n_bits - 1) - 1)

tiled_qx = jax.lax.round(
tiled_x / tiled_scale, jax.lax.RoundingMethod.TO_NEAREST_EVEN
)
tiled_qx = tiled_qx.astype(jnp.int8 if n_bits == 8 else jnp.int4)
  qvalue = jnp.reshape(tiled_qx, x.shape)
scale = jnp.squeeze(tiled_scale, axis=tiled_calibration_axes)

scale_block_shape = tuple([
1 if axis in calibration_axes else ndim
for axis, ndim in enumerate(block_spec.block_shape)
])

  # Transpose scale such that:
  # - the size of the last dimension is a multiple of 128, and
  # - the size of the second-to-last dimension is 1.

  # Find the innermost dimension whose size is a multiple of 128.
large_dim = 0
for axis, ndim in enumerate(scale.shape):
if ndim >= 128 and ndim % 128 == 0:
large_dim = axis
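  # For example (hypothetical shapes, not part of this change): a scale of
  # shape (4, 512) yields large_dim == 1, while (512, 4) yields large_dim == 0.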

scale_permute_axis = list(range(scale.ndim))
  # Move large_dim to the last position.
scale_permute_axis = _pop_and_append(scale_permute_axis, large_dim)

# transpose scale and its block shape accordingly
scale = jnp.transpose(scale, scale_permute_axis)
scale_block_shape = [scale_block_shape[ax] for ax in scale_permute_axis]

  # Ensure the second-to-last dimension has size 1.
is_expand_dims = scale.shape[-2] != 1
if is_expand_dims:
scale = jnp.expand_dims(scale, axis=-2)
scale_permute_axis.insert(len(scale_permute_axis) - 1, -1)
scale_block_shape = (*scale_block_shape[:-1], 1, scale_block_shape[-1])

def scale_index_map(*args):
index = block_spec.index_map(*args)
index = _pop_and_append(index, large_dim)
if is_expand_dims:
index = (*index[:-1], 0, index[-1])
return index

scale_block_spec = pl.BlockSpec(
index_map=scale_index_map,
block_shape=scale_block_shape,
)
qx = QTensor(
qvalue=qvalue,
scale=[scale],
scale_t=None,
dequant_dtype=scale.dtype,
scale_permute_axis=[scale_permute_axis],
)

qx_block_spec = dataclasses.replace(
qx,
qvalue=block_spec,
scale=[scale_block_spec],
)
return qx, qx_block_spec


def materialize_qtensor(qtensor: QTensor) -> QTensor:
"""Materialize QTensor of MemoryRef of pallas into QTensor of jax.Array."""
qvalue = qtensor.qvalue
scale = qtensor.scale
scale_t = qtensor.scale_t

if qvalue is not None:
qvalue = qvalue[...]
if scale is not None:
scale = [s[...] for s in scale]
if scale_t is not None:
scale_t = [st[...] for st in scale_t]

return qtensor.replace(qvalue=qvalue, scale=scale, scale_t=scale_t)
133 changes: 133 additions & 0 deletions aqt/jax/v2/aqt_pallas_test.py
@@ -0,0 +1,133 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Test for AQT pallas."""

import functools

from absl.testing import absltest
from absl.testing import parameterized
from aqt.jax.v2 import aqt_pallas
from aqt.jax.v2 import aqt_tensor


import jax
from jax.experimental import pallas as pl

import jax.numpy as jnp
import numpy as np


class AqtPallasTest(parameterized.TestCase):

@parameterized.parameters(
((512, 512), (0,), (128, 128), (4, 1, 512), (1, 1, 128), [0, -1, 1]),
((512, 512), (1,), (128, 128), (4, 1, 512), (1, 1, 128), [1, -1, 0]),
(
(512, 512, 1024),
(1, 2),
(128, 128, 128),
(4, 8, 1, 512),
(1, 1, 1, 128),
[1, 2, -1, 0],
),
)
def test_quant_blockwisely_correctness(
self,
tensor_shape,
calibration_axes,
block_shape,
expected_scale_shape,
expected_scale_block_shape,
expected_scale_permute_axis,
):
"""Test whether QTenor can be used as an argument in pallas."""
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, tensor_shape)
block_spec = pl.BlockSpec(lambda *args: args, block_shape)

qx, qx_blockspec = aqt_pallas.quant_blockwisely(
x, n_bits=8, calibration_axes=calibration_axes, block_spec=block_spec
)
self.assertEqual(qx.qvalue.shape, x.shape)
self.assertEqual(qx.scale[0].shape, expected_scale_shape)
self.assertEqual(qx.scale_permute_axis[0], expected_scale_permute_axis)
self.assertIsNone(qx.scale_t)
self.assertEqual(
qx_blockspec.scale[0].block_shape, expected_scale_block_shape
)

@parameterized.parameters(
(
(1024, 1024),
(1,),
(256, 256),
),
(
(1024, 1024),
(0,),
(256, 256),
),
(
(10, 512, 1024),
(1,),
(1, 256, 256),
),
(
(10, 512, 1024),
(2,),
(1, 256, 256),
),
)
def test_quant_dequant(
self, tensor_shape, calibration_axes, block_shape
):
"""Test whether QTenor can be used as an argument in pallas."""
key = jax.random.PRNGKey(0)
x = jax.random.uniform(key, tensor_shape, minval=-3, maxval=3)
block_spec = pl.BlockSpec(lambda *args: args, block_shape)

@functools.partial(jax.jit, static_argnames=["block_spec"])
def quant_dequant(x, block_spec):
qx, qx_blockspec = aqt_pallas.quant_blockwisely(
x,
n_bits=8,
calibration_axes=calibration_axes,
block_spec=block_spec,
)
grid = [
ndim // blk_ndim for ndim, blk_ndim in zip(tensor_shape, block_shape)
]

def dequant_kernel(qx: aqt_tensor.QTensor, out_ref):
qx = aqt_pallas.materialize_qtensor(qx)
out_ref[...] = qx.dequant()

dequant_out = pl.pallas_call(
dequant_kernel,
grid=tuple(grid),
in_specs=[qx_blockspec],
out_specs=block_spec,
out_shape=jax.ShapeDtypeStruct(shape=tensor_shape, dtype=jnp.float32),
interpret=False
)(qx)
return dequant_out

np.testing.assert_array_almost_equal(
quant_dequant(x, block_spec), x, decimal=1
)


if __name__ == "__main__":
absltest.main()
49 changes: 47 additions & 2 deletions aqt/jax/v2/aqt_tensor.py
@@ -48,6 +48,29 @@
ArrayT: TypeAlias = jnp.ndarray


def _restore_from_permutation(
scale: jax.Array, permute_axis: list[utils.AxisIdx]
) -> jax.Array:
"""Restores the scale from permutation and expansion."""
expanded_axis = []
transpose_axis = []
for i, axis in enumerate(permute_axis):
if axis == -1:
expanded_axis.append(i)
else:
transpose_axis.append(axis)

  scale = jnp.squeeze(scale, expanded_axis)
  # transpose_axis[i] is the original axis stored at position i, so the inverse
  # permutation is needed to restore the original layout.
  inverse_axis = [transpose_axis.index(i) for i in range(len(transpose_axis))]
  # If the whole tensor is as large as its largest axis (i.e. every other axis
  # has size 1), a reshape is enough; otherwise transpose.
  reshapable = max(scale.shape) == scale.size
  if reshapable:
    scale = jnp.reshape(scale, [scale.shape[i] for i in inverse_axis])
  else:
    scale = jax.lax.transpose(scale, inverse_axis)
  return scale


@utils.flax_slots_dataclass
class QTensor:
"""Quantized tensor."""
@@ -74,6 +97,22 @@ class QTensor:
pytree_node=False, default=None
)

  # The permutation and expansion applied to each scale factor.
  # -1 indicates an expanded dim.
  #
  # This is needed for pallas. The shape of a tensor passed as an argument to
  # pl.pallas_call has the following constraints:
  # - the size of the last dimension should be at least 128, and
  # - the sizes of the last two dimensions should be either
  #   - at least (8, 128), or
  #   - (1, a multiple of 128).
  # To meet these constraints, scales are transposed and expanded: an axis
  # whose size is a multiple of 128 is moved to the last dimension, and the
  # second-to-last dimension is expanded to size 1. The scale is restored
  # during dequantization.
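  # Illustrative example (shapes from the unit tests, as a sketch): a scale of
  # shape (512, 4) is transposed to (4, 512) and expanded to (4, 1, 512),
  # recorded as scale_permute_axis [1, -1, 0].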
scale_permute_axis: Optional[list[list[utils.AxisIdx]]] = flax.struct.field(
pytree_node=False, default=None
)

def is_full(self) -> bool:
return self.qvalue is not None

@@ -82,11 +121,15 @@ def without_qvalue(self) -> Self:
return self.replace(qvalue=None) # pytype: disable=attribute-error

def quant(self, x):
"""Quantizes the QTensor."""

assert not self.is_full(), 'Already quantized QTensor.'
assert self.scale is not None, 'Missing scales to be used for quantization.'

qvalue = x
for s in self.scale:
for i, s in enumerate(self.scale):
if self.scale_permute_axis is not None:
s = _restore_from_permutation(s, self.scale_permute_axis[i])
qvalue = qvalue * jax.lax.reciprocal(s)

# TODO(lew): We should apply numerics here, so that 'quant' function
@@ -103,7 +146,9 @@ def dequant(self) -> jnp.ndarray:
assert self.dequant_dtype is not None, msg
assert self.is_full(), _MSG_NO_QVALUE
ret = self.qvalue
for scale in self.scale:
for i, scale in enumerate(self.scale):
if self.scale_permute_axis is not None:
scale = _restore_from_permutation(scale, self.scale_permute_axis[i])
ret = ret.astype(self.dequant_dtype) * scale.astype(self.dequant_dtype) # pytype: disable=attribute-error
return ret # pytype: disable=bad-return-type

