Skip to content

Commit

Permalink
use cpu platform
Browse files Browse the repository at this point in the history
  • Loading branch information
hanzhi713 committed Nov 19, 2024
1 parent dca992f commit c806456
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions axlearn/common/checkpointer_orbax_emergency.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@
import copy
import functools
import hashlib
import multiprocessing as mp
import os
import time
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from multiprocessing import Process
from typing import Any, Dict, List, Optional, Tuple, Union

import jax
Expand Down Expand Up @@ -418,7 +418,10 @@ def get_consistent_proc_info(
used as the global coordinator address.
"""
start_t = time.perf_counter()
proc = Process(
platform = os.environ.get("JAX_PLATFORMS", "")
# Patch platform so the process doesn't waste time initializing accelerators.
os.environ["JAX_PLATFORMS"] = "cpu"
proc = mp.get_context("spawn").Process(
target=_init_consistent_proc_ids,
kwargs=dict(
local_address=local_address,
Expand All @@ -431,6 +434,9 @@ def get_consistent_proc_info(
proc.start()
proc.join()
assert proc.exitcode == 0
# Restore previous platform settings.
if platform != "":
os.environ["JAX_PLATFORMS"] = platform

info = _get_previous_process_info(local_ckpt_dir, unique_str=trainer_dir)
assert info.inv_proc_id != -1
Expand Down

0 comments on commit c806456

Please sign in to comment.