forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathc10d_logger.py
91 lines (70 loc) · 2.69 KB
/
c10d_logger.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import functools
import logging
import time
from typing import Any, Dict, List, Tuple
import torch.distributed as dist
from torch.distributed.logging_handlers import _log_handlers
__all__: List[str] = []
def _get_or_create_logger() -> logging.Logger:
    """Return the c10d logger wired to the default logging handler.

    The handler and its class name come from ``_get_logging_handler()``; the
    logger is named ``c10d-<handler class name>``.

    Returns:
        A ``logging.Logger`` set to ``DEBUG``, with propagation to ancestor
        loggers disabled and the handler attached.
    """
    handler, handler_name = _get_logging_handler()
    handler.setFormatter(
        logging.Formatter(
            "%(asctime)s %(filename)s:%(lineno)s %(levelname)s p:%(processName)s t:%(threadName)s: %(message)s"
        )
    )

    c10d_logger = logging.getLogger(f"c10d-{handler_name}")
    c10d_logger.setLevel(logging.DEBUG)
    # Records go only to the attached handler, not to ancestor loggers.
    c10d_logger.propagate = False
    c10d_logger.addHandler(handler)
    return c10d_logger
def _get_logging_handler(destination: str = "default") -> Tuple[logging.Handler, str]:
    """Look up the handler registered for ``destination``.

    Args:
        destination: Key into ``_log_handlers``; defaults to ``"default"``.

    Returns:
        A ``(handler, name)`` tuple where ``name`` is the handler's class name.

    Raises:
        KeyError: If ``destination`` is not a key of ``_log_handlers``.
    """
    handler = _log_handlers[destination]
    return handler, type(handler).__name__
# Module-wide logger used by the decorators in this file; created once at
# import time. No ``global`` statement is needed: an assignment at module
# scope already binds a module-level name.
_c10d_logger: logging.Logger = _get_or_create_logger()
def _get_msg_dict(func_name, *args, **kwargs) -> Dict[str, Any]:
    """Build the dictionary that is logged for one c10d call.

    Args:
        func_name: Name of the function being logged.
        *args: Positional arguments the function was called with.
        **kwargs: Keyword arguments the function was called with. The
            ``group`` entry (``None`` when absent) is passed to the
            ``torch.distributed`` queries below; a ``pg`` entry, when
            present, takes precedence for the process-group name lookup.

    Returns:
        A dict whose values are all strings. ``func_name`` and ``args`` are
        always present. When ``dist.is_initialized()`` is true, ``pg_name``,
        ``backend``, ``world_size``, ``group_size``, ``global_rank`` and
        ``local_rank`` are added as well.
    """
    msg_dict: Dict[str, Any] = {
        "func_name": f"{func_name}",
        "args": f"{args}, {kwargs}",
    }
    if not dist.is_initialized():
        return msg_dict

    group = kwargs.get("group")
    # Every other group-dependent field is keyed off ``group``, but the name
    # lookup used to read only ``pg``, so a ``group=`` argument never reached
    # it. Keep honouring an explicit ``pg`` and fall back to ``group``.
    pg = kwargs.get("pg")
    if pg is None:
        pg = group

    msg_dict.update(
        {
            "pg_name": f"{dist._get_process_group_name(pg)}",  # type: ignore[arg-type]
            "backend": f"{dist.get_backend(group)}",
            "world_size": f"{dist.get_world_size()}",
            "group_size": f"{dist.get_world_size(group)}",
            "global_rank": f"{dist.get_rank()}",
            "local_rank": f"{dist.get_rank(group)}",
        }
    )
    return msg_dict
def _exception_logger(func):
    """Decorate ``func`` so that any ``Exception`` it raises is logged.

    On failure the wrapper builds the message dictionary for the call via
    ``_get_msg_dict``, adds the stringified exception under ``"error"``,
    emits it on ``_c10d_logger`` at DEBUG level and re-raises the same
    exception. Successful calls are returned untouched and not logged.

    Args:
        func: The callable to wrap.

    Returns:
        The wrapping callable, carrying ``func``'s metadata.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        try:
            result = func(*args, **kwargs)
        except Exception as exc:
            record = {
                **_get_msg_dict(func.__name__, *args, **kwargs),
                "error": f"{exc}",
            }
            _c10d_logger.debug(record)
            # Bare ``raise`` keeps the original exception and traceback.
            raise
        return result

    return wrapper
def _time_logger(func):
    """Decorate ``func`` so that the duration of each completed call is logged.

    After ``func`` returns, the wrapper builds the message dictionary for
    the call via ``_get_msg_dict``, adds the elapsed time under
    ``"time_spent"`` (an integer nanosecond count followed by ``ns``) and
    emits it on ``_c10d_logger`` at DEBUG level. If ``func`` raises, the
    exception propagates and nothing is logged here.

    Args:
        func: The callable to wrap.

    Returns:
        The wrapping callable, carrying ``func``'s metadata.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # Use a monotonic clock: the wall clock (``time.time_ns``) can be
        # adjusted while ``func`` runs, yielding wrong or negative durations.
        start_ns = time.perf_counter_ns()
        func_return = func(*args, **kwargs)
        time_spent = time.perf_counter_ns() - start_ns

        msg_dict = _get_msg_dict(func.__name__, *args, **kwargs)
        msg_dict["time_spent"] = f"{time_spent}ns"
        _c10d_logger.debug(msg_dict)

        return func_return

    return wrapper