Skip to content

Commit

Permalink
Enable support for Kueue for GKETPUJob (#623)
Browse files: browse the repository at this point in the history
samos123 authored Aug 13, 2024
1 parent 9797b53 commit 66704ea
Showing 1 changed file with 18 additions and 14 deletions.
32 changes: 18 additions & 14 deletions axlearn/cloud/gcp/job.py
Original file line number | Diff line number | Diff line change
@@ -313,13 +313,15 @@ class Config(GCPJob.Config):
gcsfuse_mount: Optional configs for the GCS FUSE sidecar and volume mount.
See `GCSFuseMount` for details.
enable_pre_provisioner: Whether to enable pre-provisioner.
queue: The Kueue LocalQueue to use. If not set, no queue is used.
"""

env_vars: Dict[str, str] = {}
namespace: str = "default"
gcsfuse_mount: Optional[GCSFuseMount] = None
# This config is made Optional for backwards compatibility. Default is False.
enable_pre_provisioner: Optional[bool] = None
queue: Optional[str] = None

@classmethod
def define_flags(cls, fv: flags.FlagValues):
@@ -333,6 +335,12 @@ def define_flags(cls, fv: flags.FlagValues):
"GCS FUSE mount spec in the format key=value.",
flag_values=fv,
)
flags.DEFINE_string(
"queue",
None,
"The name of the Kueue LocalQueue to use. If not set, no queue is used.",
flag_values=fv,
)

@classmethod
def from_flags(cls, fv: flags.FlagValues, **kwargs) -> Config:
@@ -610,15 +618,19 @@ def _build_jobset(self) -> Nested[Any]:
"""
cfg: TPUGKEJob.Config = self.config

annotations = {
# The exclusive topology annotation will ensure that all Pods will have affinity
# rules added that will ensure that they are fully scheduled on the same
# pod-slice node-pools.
"alpha.jobset.sigs.k8s.io/exclusive-topology": "cloud.google.com/gke-nodepool",
}
if cfg.queue:
annotations["kueue.x-k8s.io/queue-name"] = cfg.queue

return dict(
metadata=dict(
name=cfg.name,
annotations={
# The exclusive topology annotation will ensure that all Pods will have affinity
# rules added that will ensure that they are fully scheduled on the same
# pod-slice node-pools.
"alpha.jobset.sigs.k8s.io/exclusive-topology": "cloud.google.com/gke-nodepool",
},
annotations=annotations,
),
spec=dict(
failurePolicy=dict(maxRestarts=cfg.max_tries - 1),
@@ -668,23 +680,15 @@ class Config(GKEJob.Config):
Attributes:
accelerator: GPU configuration.
queue: The Kueue LocalQueue to use. If not set, no queue is used.
"""

accelerator: AcceleratorConfig = AcceleratorConfig()
queue: Optional[str] = None

@classmethod
def define_flags(cls, fv: flags.FlagValues):
super().define_flags(fv)
common_kwargs = dict(flag_values=fv, allow_override=True)
accelerator_flags(**common_kwargs)
flags.DEFINE_string(
"queue",
None,
"The name of the Kueue LocalQueue to use. If not set, no queue is used.",
**common_kwargs,
)

@classmethod
def from_flags(cls, fv: flags.FlagValues, **kwargs) -> Config:

0 comments on commit 66704ea

Please sign in to comment.