#include <torch/csrc/autograd/python_cpp_function.h>
#include <torch/csrc/distributed/autograd/autograd.h>
#include <torch/csrc/jit/python/pybind_utils.h>
#include <torch/csrc/python_headers.h>
#include <torch/csrc/utils/object_ptr.h>
#include <torch/csrc/utils/pybind.h>
#include <torch/types.h>
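
// Python bindings for the distributed autograd module (in the PyTorch tree
// this file is torch/csrc/distributed/autograd/init.cpp). dist_autograd_init()
// registers the torch._C._distributed_autograd submodule; the function itself
// is exported through python_functions() below and is typically invoked once
// from the Python side as torch._C._dist_autograd_init() when
// torch.distributed.autograd is imported (an assumption about the surrounding
// build, not something enforced in this file).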
namespace torch {
namespace distributed {
namespace autograd {
namespace {
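
// Convenience alias: bind a class with std::shared_ptr as the pybind11 holder
// type, so ownership of the bound objects is shared between C++ and Python.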
template <typename T>
using shared_ptr_class_ = py::class_<T, std::shared_ptr<T>>;
PyObject* dist_autograd_init(PyObject* _unused, PyObject* noargs) {
  auto autograd_module =
      THPObjectPtr(PyImport_ImportModule("torch.distributed.autograd"));
  if (!autograd_module) {
    throw python_error();
  }

  auto torch_C_module = THPObjectPtr(PyImport_ImportModule("torch._C"));
  if (!torch_C_module) {
    throw python_error();
  }

  auto torch_C_m = py::handle(torch_C_module).cast<py::module>();
  auto m = torch_C_m.def_submodule(
      "_distributed_autograd", "distributed autograd bindings");

  auto module = py::handle(m).cast<py::module>();

  auto distAutogradContext =
      shared_ptr_class_<DistAutogradContext>(module, "DistAutogradContext")
          .def(
              "_context_id",
              &DistAutogradContext::contextId,
              py::call_guard<py::gil_scoped_release>())
          .def(
              "_recv_functions",
              [](const DistAutogradContext& ctx) {
                std::map<int64_t, py::object> funcs;
                auto recvFunctions = ctx.recvFunctions();
                // Acquire GIL only when necessary to avoid deadlocks.
                pybind11::gil_scoped_acquire ag;
                for (const auto& map_entry : recvFunctions) {
                  funcs.emplace(
                      map_entry.first,
                      py::reinterpret_steal<py::object>(
                          torch::autograd::functionToPyObject(
                              map_entry.second)));
                }
                return funcs;
              },
              py::call_guard<py::gil_scoped_release>())
          .def(
              "_send_functions",
              [](const ContextPtr& ctx) {
                std::map<int64_t, py::object> funcs;
                auto sendFunctions = ctx->sendFunctions();
                // Acquire GIL only when necessary to avoid deadlocks.
                pybind11::gil_scoped_acquire ag;
                for (const auto& map_entry : sendFunctions) {
                  funcs.emplace(
                      map_entry.first,
                      py::reinterpret_steal<py::object>(
                          torch::autograd::functionToPyObject(
                              map_entry.second)));
                }
                return funcs;
              },
              py::call_guard<py::gil_scoped_release>())
          .def(
              "_known_worker_ids",
              &DistAutogradContext::getKnownWorkerIds,
              py::call_guard<py::gil_scoped_release>());
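
  // The free functions below are thin wrappers around the
  // DistAutogradContainer singleton. The Python-side
  // torch.distributed.autograd.context() manager is expected to drive
  // _new_context() / _release_context(), while the remaining bindings are
  // mostly introspection helpers (an assumption about the Python callers,
  // not enforced here).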
  module.def(
      "_new_context",
      []() -> const ContextPtr {
        return DistAutogradContainer::getInstance().newContext();
      },
      py::return_value_policy::reference,
      py::call_guard<py::gil_scoped_release>());

  module.def(
      "_release_context",
      [](int64_t context_id) {
        return DistAutogradContainer::getInstance().releaseContext(context_id);
      },
      py::call_guard<py::gil_scoped_release>());

  module.def(
      "_get_max_id",
      []() { return DistAutogradContainer::getInstance().getMaxId(); },
      py::call_guard<py::gil_scoped_release>());
  module.def(
      "_is_valid_context",
      [](int64_t context_id) {
        // Validates that a context with this id exists on the current worker.
        DistAutogradContainer::getInstance().isValidContext(context_id);
      },
      py::call_guard<py::gil_scoped_release>());
  module.def(
      "_retrieve_context",
      [](int64_t context_id) -> const ContextPtr {
        return DistAutogradContainer::getInstance().retrieveContext(context_id);
      },
      py::return_value_policy::reference,
      py::call_guard<py::gil_scoped_release>());

  module.def(
      "_current_context",
      []() -> const ContextPtr {
        return DistAutogradContainer::getInstance().currentContext();
      },
      py::return_value_policy::reference,
      py::call_guard<py::gil_scoped_release>());

  module.def(
      "_init",
      [](int64_t worker_id) { DistAutogradContainer::init(worker_id); },
      py::call_guard<py::gil_scoped_release>());

  module.def(
      "_get_debug_info",
      []() { return DistEngine::getInstance().getDebugInfo(); },
      py::call_guard<py::gil_scoped_release>());
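
  // The docstrings below embed hand-written signatures, so suppress
  // pybind11's auto-generated signatures to avoid duplicating them in the
  // rendered documentation.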
  py::options options;
  options.disable_function_signatures();

  module.def(
      "backward",
      backward,
      R"(
backward(context_id: int, roots: List[Tensor], retain_graph = False) -> None

Kicks off the distributed backward pass using the provided roots. This
currently implements the :ref:`fast-mode-algorithm`, which assumes all RPC
messages sent in the same distributed autograd context across workers will be
part of the autograd graph during the backward pass.

We use the provided roots to discover the autograd graph and compute
appropriate dependencies. This method blocks until the entire autograd
computation is done.

We accumulate the gradients in the appropriate
:class:`torch.distributed.autograd.context` on each of the nodes. The autograd
context to be used is looked up given the ``context_id`` that is passed in when
:meth:`torch.distributed.autograd.backward` is called. If there is no valid
autograd context corresponding to the given ID, we throw an error. You can
retrieve the accumulated gradients using the
:meth:`~torch.distributed.autograd.get_gradients` API.

Arguments:
    context_id (int): The autograd context id for which the backward pass
        should be run and gradients accumulated.
    roots (list): Tensors which represent the roots of the autograd
        computation. All the tensors should be scalars.
    retain_graph (bool, optional): If False, the graph used to compute the
        grad will be freed. Note that in nearly all cases setting this option
        to True is not needed and often can be worked around in a much more
        efficient way. Usually, you need to set this to True to run backward
        multiple times.

Example::
    >>> import torch.distributed.autograd as dist_autograd
    >>> with dist_autograd.context() as context_id:
    >>>     pred = model.forward()
    >>>     loss = loss_func(pred, target)
    >>>     dist_autograd.backward(context_id, [loss])
)",
      py::arg("contextId"),
      py::arg("roots"),
      py::arg("retain_graph") = false,
      py::call_guard<py::gil_scoped_release>());
  module.def(
      "get_gradients",
      [](int64_t contextId) -> py::dict {
        const auto& autogradContext =
            DistAutogradContainer::getInstance().retrieveContext(contextId);
        auto ival = IValue(autogradContext->getGradients());
        // Acquire GIL only for the pyobject conversion.
        pybind11::gil_scoped_acquire ag;
        return torch::jit::toPyObject(ival);
      },
      R"(
get_gradients(context_id: int) -> Dict[Tensor, Tensor]

Retrieves a map from Tensor to the appropriate gradient for that Tensor
accumulated in the provided context corresponding to the given ``context_id``
as part of the distributed autograd backward pass.

Arguments:
    context_id (int): The autograd context id for which we should retrieve the
        gradients.

Returns:
    A map where the key is the Tensor and the value is the associated gradient
    for that Tensor.

Example::
    >>> import torch.distributed.autograd as dist_autograd
    >>> with dist_autograd.context() as context_id:
    >>>     t1 = torch.rand((3, 3), requires_grad=True)
    >>>     t2 = torch.rand((3, 3), requires_grad=True)
    >>>     loss = t1 + t2
    >>>     dist_autograd.backward(context_id, [loss.sum()])
    >>>     grads = dist_autograd.get_gradients(context_id)
    >>>     print(grads[t1])
    >>>     print(grads[t2])
)",
      py::arg("context_id"),
      py::call_guard<py::gil_scoped_release>());
  Py_RETURN_TRUE;
}
} // namespace
static PyMethodDef methods[] = { // NOLINT
    {"_dist_autograd_init", dist_autograd_init, METH_NOARGS, nullptr},
    {nullptr, nullptr, 0, nullptr}};
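
// Method table exported to the top-level torch._C module initialization,
// which (we assume) merges it with the other extension methods so that
// _dist_autograd_init becomes available as torch._C._dist_autograd_init.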
PyMethodDef* python_functions() {
  return methods;
}
} // namespace autograd
} // namespace distributed
} // namespace torch