diff --git a/axlearn/cloud/gcp/job.py b/axlearn/cloud/gcp/job.py index 5e08ba64..b7a4b985 100644 --- a/axlearn/cloud/gcp/job.py +++ b/axlearn/cloud/gcp/job.py @@ -313,6 +313,7 @@ class Config(GCPJob.Config): gcsfuse_mount: Optional configs for the GCS FUSE sidecar and volume mount. See `GCSFuseMount` for details. enable_pre_provisioner: Whether to enable pre-provisioner. + queue: The Kueue LocalQueue to use. If not set, no queue is used. """ env_vars: Dict[str, str] = {} @@ -320,6 +321,7 @@ class Config(GCPJob.Config): gcsfuse_mount: Optional[GCSFuseMount] = None # This config is made Optional for backwards compatibility. Default is False. enable_pre_provisioner: Optional[bool] = None + queue: Optional[str] = None @classmethod def define_flags(cls, fv: flags.FlagValues): @@ -333,6 +335,12 @@ def define_flags(cls, fv: flags.FlagValues): "GCS FUSE mount spec in the format key=value.", flag_values=fv, ) + flags.DEFINE_string( + "queue", + None, + "The name of the Kueue LocalQueue to use. If not set, no queue is used.", + flag_values=fv, + ) @classmethod def from_flags(cls, fv: flags.FlagValues, **kwargs) -> Config: @@ -610,15 +618,19 @@ def _build_jobset(self) -> Nested[Any]: """ cfg: TPUGKEJob.Config = self.config + annotations = { + # The exclusive topology annotation will ensure that all Pods will have affinity + # rules added that will ensure that they are fully scheduled on the same + # pod-slice node-pools. + "alpha.jobset.sigs.k8s.io/exclusive-topology": "cloud.google.com/gke-nodepool", + } + if cfg.queue: + annotations["kueue.x-k8s.io/queue-name"] = cfg.queue + return dict( metadata=dict( name=cfg.name, - annotations={ - # The exclusive topology annotation will ensure that all Pods will have affinity - # rules added that will ensure that they are fully scheduled on the same - # pod-slice node-pools. - "alpha.jobset.sigs.k8s.io/exclusive-topology": "cloud.google.com/gke-nodepool", - }, + annotations=annotations, ), spec=dict( failurePolicy=dict(maxRestarts=cfg.max_tries - 1), @@ -668,23 +680,15 @@ class Config(GKEJob.Config): Attributes: accelerator: GPU configuration. - queue: The Kueue LocalQueue to use. If not set, no queue is used. """ accelerator: AcceleratorConfig = AcceleratorConfig() - queue: Optional[str] = None @classmethod def define_flags(cls, fv: flags.FlagValues): super().define_flags(fv) common_kwargs = dict(flag_values=fv, allow_override=True) accelerator_flags(**common_kwargs) - flags.DEFINE_string( - "queue", - None, - "The name of the Kueue LocalQueue to use. If not set, no queue is used.", - **common_kwargs, - ) @classmethod def from_flags(cls, fv: flags.FlagValues, **kwargs) -> Config: