forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathinit.cpp
862 lines (788 loc) · 35.6 KB
/
init.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
#include <torch/csrc/python_headers.h>
#include <torch/csrc/distributed/rpc/profiler/remote_profiler_manager.h>
#include <torch/csrc/distributed/rpc/profiler/server_process_global_profiler.h>
#include <torch/csrc/distributed/rpc/py_rref.h>
#include <torch/csrc/distributed/rpc/python_functions.h>
#include <torch/csrc/distributed/rpc/python_rpc_handler.h>
#include <torch/csrc/distributed/rpc/request_callback_impl.h>
#include <torch/csrc/distributed/rpc/rpc_agent.h>
#include <torch/csrc/distributed/rpc/rref_context.h>
#include <torch/csrc/distributed/rpc/tensorpipe_agent.h>
#include <torch/csrc/distributed/rpc/torchscript_functions.h>
#include <torch/csrc/distributed/rpc/types.h>
#include <torch/csrc/jit/python/pybind_utils.h>
#include <torch/csrc/utils/object_ptr.h>
#include <torch/csrc/utils/pybind.h>
#include <torch/types.h>
#include <pybind11/chrono.h>
#include <pybind11/operators.h>
namespace torch {
namespace distributed {
namespace rpc {
namespace {
// Default value of the `timeout` argument of
// `_delete_all_user_and_unforked_owner_rrefs` when Python omits it:
// 100 seconds, stored as a millisecond duration.
constexpr std::chrono::milliseconds kDeleteAllUsersTimeout{
    std::chrono::seconds{100}};
// Shorthand for a pybind11 class binding of `T` whose holder type (the
// second template argument of py::class_) is std::shared_ptr<T>.
template <class T>
using shared_ptr_class_ = py::class_<T, std::shared_ptr<T>>;
// Creates the `_distributed_rpc` submodule of `torch._C` and registers all
// RPC bindings on it. Registered as `_rpc_init` with METH_NOARGS in the
// method table at the bottom of this file, so both arguments are unused.
//
// Steps:
//   1. Import `torch.distributed.rpc` and `torch._C`; if either import
//      fails, throw python_error so the pending Python exception surfaces.
//   2. Bind RpcBackendOptions, WorkerInfo, RpcAgent and PyRRef.
//   3. Only when built with USE_TENSORPIPE: bind the TensorPipe backend
//      options and agent.
//   4. Register free functions: agent lifecycle, RRef context management,
//      rpc/remote invocation, timeouts and profiling helpers.
//
// Each binding picks its GIL policy with py::call_guard (gil_scoped_release
// or gil_scoped_acquire); bindings without a call guard run with the GIL
// held, as pybind11 invokes bound functions with the GIL acquired.
//
// Returns a new reference to Py_True on success. On failure a C++ exception
// (python_error, or an error raised by TORCH_CHECK / pybind11) propagates.
PyObject* rpc_init(PyObject* _unused, PyObject* noargs) {
  // Fail fast if the Python package cannot be imported. The module object
  // itself is not used further in this function.
  auto rpc_module =
      THPObjectPtr(PyImport_ImportModule("torch.distributed.rpc"));
  if (!rpc_module) {
    throw python_error();
  }

  auto torch_C_module = THPObjectPtr(PyImport_ImportModule("torch._C"));
  if (!torch_C_module) {
    throw python_error();
  }

  // Every binding below is attached to torch._C._distributed_rpc.
  auto torch_C_m = py::handle(torch_C_module).cast<py::module>();
  auto m =
      torch_C_m.def_submodule("_distributed_rpc", "distributed rpc bindings");

  auto module = py::handle(m).cast<py::module>();

  // ---------------------------------------------------------------------
  // RpcBackendOptions. Kept in a variable because it is the Python base
  // class of _TensorPipeRpcBackendOptionsBase (USE_TENSORPIPE builds).
  // ---------------------------------------------------------------------
  auto rpcBackendOptions =
      shared_ptr_class_<RpcBackendOptions>(
          module,
          "RpcBackendOptions",
          R"(An abstract structure encapsulating the options passed into the RPC
backend. An instance of this class can be passed in to
:meth:`~torch.distributed.rpc.init_rpc` in order to initialize RPC
with specific configurations, such as the RPC timeout and
``init_method`` to be used. )")
          .def(py::init<>())
          .def(
              py::init<float, std::string>(),
              py::arg("rpc_timeout") = kDefaultRpcTimeoutSeconds,
              py::arg("init_method") = kDefaultInitMethod)
          .def_readwrite(
              "rpc_timeout",
              &RpcBackendOptions::rpcTimeoutSeconds,
              R"(A float indicating the timeout to use for all
RPCs. If an RPC does not complete in this timeframe, it will
complete with an exception indicating that it has timed out.)")
          .def_readwrite(
              "init_method",
              &RpcBackendOptions::initMethod,
              R"(URL specifying how to initialize the process group.
Default is ``env://``)");

  // The following C++ constants need to be cast so they can be used from
  // python.
  module.attr("_DEFAULT_RPC_TIMEOUT_SEC") = py::cast(kDefaultRpcTimeoutSeconds);
  module.attr("_UNSET_RPC_TIMEOUT") = py::cast(kUnsetRpcTimeout);
  module.attr("_DEFAULT_INIT_METHOD") = py::cast(kDefaultInitMethod);

  // ---------------------------------------------------------------------
  // WorkerInfo. Registering the class is the only effect needed here, so
  // the returned py::class_ handle is not stored.
  // ---------------------------------------------------------------------
  shared_ptr_class_<WorkerInfo>(
      module,
      "WorkerInfo",
      R"(A structure that encapsulates information of a worker in the system.
Contains the name and ID of the worker. This class is not meant to
be constructed directly, rather, an instance can be retrieved
through :meth:`~torch.distributed.rpc.get_worker_info` and the
result can be passed in to functions such as
:meth:`~torch.distributed.rpc.rpc_sync`, :meth:`~torch.distributed.rpc.rpc_async`,
:meth:`~torch.distributed.rpc.remote` to avoid copying a string on
every invocation.)")
      .def(
          py::init<std::string, worker_id_t>(),
          py::arg("name"),
          py::arg("id"))
      .def_readonly(
          "name", &WorkerInfo::name_, R"(The name of the worker.)")
      .def_readonly(
          "id",
          &WorkerInfo::id_,
          R"(Globally unique id to identify the worker.)")
      .def("__eq__", &WorkerInfo::operator==, py::is_operator())
      // pybind11 suggests the syntax .def(hash(py::self)), with the
      // unqualified "hash" function call. However the
      // argument-dependent lookup for the function "hash" doesn't get
      // triggered in this context because it conflicts with the struct
      // c10::hash, so we need to use the qualified name
      // py::detail::hash, which unfortunately is in a detail namespace.
      .def(py::detail::hash(py::self)) // NOLINT
      .def(
          "__repr__",
          [](const WorkerInfo& workerInfo) {
            std::ostringstream os;
            os << workerInfo;
            return os.str();
          })
      // Pickled state is the (name, id) tuple; __setstate__ rejects any
      // tuple that does not have exactly two elements.
      .def(py::pickle(
          /* __getstate__ */
          [](const WorkerInfo& workerInfo) {
            return py::make_tuple(workerInfo.name_, workerInfo.id_);
          },
          /* __setstate__ */
          [](py::tuple t) {
            TORCH_CHECK(t.size() == 2, "Invalid WorkerInfo state.");
            WorkerInfo info(
                t[0].cast<std::string>(), t[1].cast<worker_id_t>());
            return info;
          }));

  // ---------------------------------------------------------------------
  // RpcAgent (abstract agent interface). Kept in a variable because it is
  // the Python base class of TensorPipeAgent (USE_TENSORPIPE builds).
  // ---------------------------------------------------------------------
  auto rpcAgent =
      shared_ptr_class_<RpcAgent>(module, "RpcAgent")
          .def(
              "join",
              &RpcAgent::join,
              py::call_guard<py::gil_scoped_release>(),
              py::arg("shutdown") = false,
              py::arg("timeout") = 0)
          .def(
              "sync", &RpcAgent::sync, py::call_guard<py::gil_scoped_release>())
          .def(
              "shutdown",
              &RpcAgent::shutdown,
              py::call_guard<py::gil_scoped_release>())
          // Two overloads of getWorkerInfo: explicit member-function-pointer
          // casts select which one each Python overload binds to.
          .def(
              "get_worker_info",
              (const WorkerInfo& (RpcAgent::*)(void) const) &
                  RpcAgent::getWorkerInfo,
              py::call_guard<py::gil_scoped_release>())
          .def(
              "get_worker_info",
              (const WorkerInfo& (RpcAgent::*)(const std::string&) const) &
                  RpcAgent::getWorkerInfo,
              py::call_guard<py::gil_scoped_release>())
          .def(
              "get_worker_infos",
              &RpcAgent::getWorkerInfos,
              py::call_guard<py::gil_scoped_release>())
          .def(
              "_get_device_map",
              &RpcAgent::getDeviceMap,
              py::call_guard<py::gil_scoped_release>())
          .def(
              "get_debug_info",
              &RpcAgent::getDebugInfo,
              py::call_guard<py::gil_scoped_release>())
          .def(
              "get_metrics",
              &RpcAgent::getMetrics,
              py::call_guard<py::gil_scoped_release>());

  // ---------------------------------------------------------------------
  // PyRRef (exposed to Python as "PyRRef"). Registering the class is the
  // only effect needed, so the returned handle is not stored.
  // ---------------------------------------------------------------------
  shared_ptr_class_<PyRRef>(module, "PyRRef", R"(
A class encapsulating a reference to a value of some type on a remote
worker. This handle will keep the referenced remote value alive on the
worker. A ``UserRRef`` will be deleted when 1) no references to it in
both the application code and in the local RRef context, or 2) the
application has called a graceful shutdown. Invoking methods on a
deleted RRef leads to undefined behaviors. RRef implementation only
offers best-effort error detection, and applications should not use
``UserRRefs`` after ``rpc.shutdown()``.
.. warning::
RRefs can only be serialized and deserialized by the RPC module.
Serializing and deserializing RRefs without RPC (e.g., Python
pickle, torch :meth:`~torch.save` / :meth:`~torch.load`,
JIT :meth:`~torch.jit.save` / :meth:`~torch.jit.load`, etc.) will
lead to errors.
Args:
value (object): The value to be wrapped by this RRef.
type_hint (Type, optional): Python type that should be passed to
``TorchScript`` compiler as type hint for ``value``.
Example::
Following examples skip RPC initialization and shutdown code
for simplicity. Refer to RPC docs for those details.
1. Create an RRef using rpc.remote
>>> import torch
>>> import torch.distributed.rpc as rpc
>>> rref = rpc.remote("worker1", torch.add, args=(torch.ones(2), 3))
>>> # get a copy of value from the RRef
>>> x = rref.to_here()
2. Create an RRef from a local object
>>> import torch
>>> from torch.distributed.rpc import RRef
>>> x = torch.zeros(2, 2)
>>> rref = RRef(x)
3. Share an RRef with other workers
>>> # On both worker0 and worker1:
>>> def f(rref):
>>> return rref.to_here() + 1
>>> # On worker0:
>>> import torch
>>> import torch.distributed.rpc as rpc
>>> from torch.distributed.rpc import RRef
>>> rref = RRef(torch.zeros(2, 2))
>>> # the following RPC shares the rref with worker1, reference
>>> # count is automatically updated.
>>> rpc.rpc_sync("worker1", f, args=(rref,))
)")
      .def(
          py::init<const py::object&, const py::object&>(),
          py::arg("value"),
          py::arg("type_hint") = py::none())
      .def(
          // not releasing GIL here to avoid context switch on getters
          "is_owner",
          &PyRRef::isOwner,
          R"(
Returns whether or not the current node is the owner of this
``RRef``.
)")
      .def(
          "confirmed_by_owner",
          &PyRRef::confirmedByOwner,
          R"(
Returns whether this ``RRef`` has been confirmed by the owner.
``OwnerRRef`` always returns true, while ``UserRRef`` only
returns true when the owner knowns about this ``UserRRef``.
)")
      .def(
          // not releasing GIL here to avoid context switch on getters
          "owner",
          &PyRRef::owner,
          R"(
Returns worker information of the node that owns this ``RRef``.
)")
      .def(
          // not releasing GIL here to avoid context switch on getters
          "owner_name",
          &PyRRef::ownerName,
          R"(
Returns worker name of the node that owns this ``RRef``.
)")
      .def(
          "to_here",
          &PyRRef::toHere,
          py::arg("timeout") = py::cast(kUnsetRpcTimeout),
          py::call_guard<py::gil_scoped_release>(),
          R"(
Blocking call that copies the value of the RRef from the owner
to the local node and returns it. If the current node is the
owner, returns a reference to the local value.
Args:
timeout (float, optional): Timeout for ``to_here``. If
the call does not complete within this timeframe, an
exception indicating so will be raised. If this
argument is not provided, the default RPC timeout
(60s) will be used.
)")
      .def(
          "local_value",
          &PyRRef::localValue,
          py::call_guard<py::gil_scoped_release>(),
          R"(
If the current node is the owner, returns a reference to the
local value. Otherwise, throws an exception.
)")
      // rpc_sync / rpc_async / remote all build a proxy via
      // createRRefProxy, differing only in the RRefProxyType.
      .def(
          "rpc_sync",
          [](const PyRRef& self, float timeoutSeconds) {
            return self.createRRefProxy(
                RRefProxyType::RPC_SYNC, timeoutSeconds);
          },
          py::arg("timeout") = kUnsetRpcTimeout,
          py::call_guard<py::gil_scoped_release>(),
          R"(
Create a helper proxy to easily launch an ``rpc_sync`` using
the owner of the RRef as the destination to run functions on
the object referenced by this RRef. More specifically,
``rref.rpc_sync().func_name(*args, **kwargs)`` is the same as
the following:
>>> def run(rref, func_name, args, kwargs):
>>> return getattr(rref.local_value(), func_name)(*args, **kwargs)
>>>
>>> rpc.rpc_sync(rref.owner(), run, args=(rref, func_name, args, kwargs))
Args:
timeout (float, optional): Timeout for ``rref.rpc_sync()``.
If the call does not complete within this timeframe, an
exception indicating so will be raised. If this argument
is not provided, the default RPC timeout will be used.
Example::
>>> from torch.distributed import rpc
>>> rref = rpc.remote("worker1", torch.add, args=(torch.zeros(2, 2), 1))
>>> rref.rpc_sync().size() # returns torch.Size([2, 2])
>>> rref.rpc_sync().view(1, 4) # returns tensor([[1., 1., 1., 1.]])
)")
      .def(
          "rpc_async",
          [](const PyRRef& self, float timeoutSeconds) {
            return self.createRRefProxy(
                RRefProxyType::RPC_ASYNC, timeoutSeconds);
          },
          py::arg("timeout") = kUnsetRpcTimeout,
          py::call_guard<py::gil_scoped_release>(),
          R"(
Create a helper proxy to easily launch an ``rpc_async`` using
the owner of the RRef as the destination to run functions on
the object referenced by this RRef. More specifically,
``rref.rpc_async().func_name(*args, **kwargs)`` is the same as
the following:
>>> def run(rref, func_name, args, kwargs):
>>> return getattr(rref.local_value(), func_name)(*args, **kwargs)
>>>
>>> rpc.rpc_async(rref.owner(), run, args=(rref, func_name, args, kwargs))
Args:
timeout (float, optional): Timeout for ``rref.rpc_async()``.
If the call does not complete within this timeframe, an
exception indicating so will be raised. If this argument
is not provided, the default RPC timeout will be used.
Example::
>>> from torch.distributed import rpc
>>> rref = rpc.remote("worker1", torch.add, args=(torch.zeros(2, 2), 1))
>>> rref.rpc_async().size().wait() # returns torch.Size([2, 2])
>>> rref.rpc_async().view(1, 4).wait() # returns tensor([[1., 1., 1., 1.]])
)")
      .def(
          "remote",
          [](const PyRRef& self, float timeoutSeconds) {
            return self.createRRefProxy(
                RRefProxyType::REMOTE, timeoutSeconds);
          },
          py::arg("timeout") = kUnsetRpcTimeout,
          py::call_guard<py::gil_scoped_release>(),
          R"(
Create a helper proxy to easily launch a ``remote`` using
the owner of the RRef as the destination to run functions on
the object referenced by this RRef. More specifically,
``rref.remote().func_name(*args, **kwargs)`` is the same as
the following:
>>> def run(rref, func_name, args, kwargs):
>>> return getattr(rref.local_value(), func_name)(*args, **kwargs)
>>>
>>> rpc.remote(rref.owner(), run, args=(rref, func_name, args, kwargs))
Args:
timeout (float, optional): Timeout for ``rref.remote()``. If
the creation of this :class:`~torch.distributed.rpc.RRef`
is not successfully completed within the timeout, then the
next time there is an attempt to use the RRef
(such as ``to_here``), a timeout will be raised. If not
provided, the default RPC timeout will be used. Please see
``rpc.remote()`` for specific timeout semantics for
:class:`~torch.distributed.rpc.RRef`.
Example::
>>> from torch.distributed import rpc
>>> rref = rpc.remote("worker1", torch.add, args=(torch.zeros(2, 2), 1))
>>> rref.remote().size().to_here() # returns torch.Size([2, 2])
>>> rref.remote().view(1, 4).to_here() # returns tensor([[1., 1., 1., 1.]])
)")
      // Plain Python pickling is deliberately rejected in both directions.
      // RRefs are serialized through _serialize/_deserialize below.
      .def(
          py::pickle(
              /* __getstate__ */
              [](const PyRRef& /* unused */) {
                TORCH_CHECK(
                    false,
                    "Can not pickle rref in python pickler, rref can only be "
                    "pickled when using RPC");
                // Note that this return has no meaning since we always
                // throw, it's only here to satisfy Pybind API's
                // requirement.
                return py::make_tuple();
              },
              /* __setstate__ */
              [](py::tuple /* unused */) { // NOLINT
                TORCH_CHECK(
                    false,
                    "Can not unpickle rref in python pickler, rref can only be "
                    "unpickled when using RPC");
                // Note that this return has no meaning since we always
                // throw, it's only here to satisfy PyBind's API
                // requirement.
                return PyRRef(
                    py::cast<py::none>(Py_None),
                    py::cast<py::none>(Py_None));
              }),
          py::call_guard<py::gil_scoped_release>())
      .def(
          "_serialize",
          &PyRRef::pickle,
          py::call_guard<py::gil_scoped_release>())
      .def_static(
          "_deserialize",
          &PyRRef::unpickle,
          py::call_guard<py::gil_scoped_release>())
      .def(
          "_get_type",
          // Intentionally not releasing GIL, as most accesses just
          // retrieve cached type py::object
          &PyRRef::getRRefType,
          py::arg("timeout") = kUnsetRpcTimeout,
          py::arg("blocking") = true,
          R"(
If ``blocking=True``, returns the type of the data object
referenced by this ``RRef``. On the owner, this is same as
``type(rref.local_value())``. Otherwise, returns a future to
this result. On a user, this will trigger an RPC to fetch the
``type`` object from the owner. After this function is run
once, the ``type`` object is cached by the ``RRef``, and
subsequent invocations no longer trigger RPC. Note that this is
true regardless of the ``blocking`` argument of subsequent
calls.
Args:
rref (torch.distributed.rpc.RRef): The RRef to get type of.
timeout (float, optional): Timeout, in seconds for
``_get_type``. If the call does not complete within
this timeframe, an exception indicating so will be
raised. If this argument is not provided, the default
RPC timeout will be used.
blocking (bool, optional): Whether to synchronously wait on
the RPC triggered by the first call and return the
type. If ``False``, will return a future. Default is
``True``.
)")
      .def(
          "_get_future",
          [](const PyRRef& self) {
            return std::make_shared<jit::PythonFutureWrapper>(
                self.getFuture());
          },
          py::call_guard<py::gil_scoped_release>(),
          R"(
Returns the future that corresponds to the creation of this RRef
on the remote node. This is for internal use cases such as profiling
only.
)")
      .def(
          "_get_profiling_future",
          [](const PyRRef& self) {
            return std::make_shared<jit::PythonFutureWrapper>(
                self.getProfilingFuture());
          },
          py::call_guard<py::gil_scoped_acquire>(),
          R"(
Returns future that completes when the profiling event corresponding
to the creation of this RRef on the remote node has been recorded.
)")
      .def(
          "_set_profiling_future",
          [](PyRRef& self,
             const std::shared_ptr<jit::PythonFutureWrapper>& wrappedFuture) {
            self.setProfilingFuture(wrappedFuture->fut);
          },
          py::call_guard<py::gil_scoped_acquire>(),
          R"(
Set future that is completed when the profiling event corresponding
to the creation of this RRef on the remote node has been recorded.
)")
      .def(
          "backward",
          [](PyRRef& self, int64_t dist_autograd_ctx_id, bool retain_graph) {
            self.backward(dist_autograd_ctx_id, retain_graph);
          },
          py::arg("dist_autograd_ctx_id") = -1,
          py::arg("retain_graph") = false,
          py::call_guard<py::gil_scoped_release>(),
          R"(
Runs the backward pass using the RRef as the root of the
backward pass. If ``dist_autograd_ctx_id`` is provided,
we perform a distributed backward pass using the provided
ctx_id starting from the owner of the RRef. In this case,
:meth:`~torch.distributed.autograd.get_gradients` should be
used to retrieve the gradients. If ``dist_autograd_ctx_id``
is ``None``, it is assumed that this is a local autograd graph
and we only perform a local backward pass. In the local case,
the node calling this API has to be the owner of the RRef.
The value of the RRef is expected to be a scalar Tensor.
Args:
dist_autograd_ctx_id (int, optional): The distributed
autograd context id for which we should retrieve the
gradients (default: -1).
retain_graph(bool, optional): If ``False``, the graph used to
compute the grad will be freed. Note that in nearly all
cases setting this option to ``True`` is not needed and
often can be worked around in a much more efficient way.
Usually, you need to set this to ``True`` to run backward
multiple times (default: False).
Example::
>>> import torch.distributed.autograd as dist_autograd
>>> with dist_autograd.context() as context_id:
>>> rref.backward(context_id)
)")
      // not releasing GIL to avoid context switch
      .def("__repr__", &PyRRef::str);

#ifdef USE_TENSORPIPE

  // ---------------------------------------------------------------------
  // TensorPipe backend: options and agent.
  // ---------------------------------------------------------------------
  // Base class: torch.distributed.rpc.RpcBackendOptions.
  py::class_<TensorPipeRpcBackendOptions>(
      module, "_TensorPipeRpcBackendOptionsBase", rpcBackendOptions)
      .def(
          py::init<
              int,
              optional<std::vector<std::string>>,
              optional<std::vector<std::string>>,
              float,
              std::string,
              std::unordered_map<std::string, DeviceMap>,
              std::vector<c10::Device>>(),
          py::arg("num_worker_threads") = kDefaultNumWorkerThreads,
          py::arg("_transports") = optional<std::vector<std::string>>(),
          py::arg("_channels") = optional<std::vector<std::string>>(),
          py::arg("rpc_timeout") = kDefaultRpcTimeoutSeconds,
          py::arg("init_method") = kDefaultInitMethod,
          py::arg("device_maps") = std::unordered_map<std::string, DeviceMap>(),
          py::arg("devices") = std::vector<c10::Device>())
      .def_readwrite(
          "num_worker_threads",
          &TensorPipeRpcBackendOptions::numWorkerThreads,
          R"(
The number of threads in the thread-pool used by
:class:`~torch.distributed.rpc.TensorPipeAgent` to execute
requests.
)")
      .def_readwrite(
          "device_maps",
          &TensorPipeRpcBackendOptions::deviceMaps,
          R"(The device map locations.)")
      .def_readwrite(
          "devices",
          &TensorPipeRpcBackendOptions::devices,
          R"(All devices used by the local agent.)")
      .def("_set_device_map", &TensorPipeRpcBackendOptions::setDeviceMap);

  module.attr("_DEFAULT_NUM_WORKER_THREADS") =
      py::cast(kDefaultNumWorkerThreads);

  shared_ptr_class_<TensorPipeAgent>(module, "TensorPipeAgent", rpcAgent)
      .def(
          py::init(
              [](const c10::intrusive_ptr<::c10d::Store>& store,
                 std::string selfName,
                 worker_id_t selfId,
                 optional<int> worldSize,
                 TensorPipeRpcBackendOptions opts,
                 std::unordered_map<std::string, DeviceMap> reverseDeviceMaps,
                 std::vector<c10::Device> devices) {
                // The custom deleter impl::destroy_without_gil is used for
                // the agent instead of a plain delete.
                return std::shared_ptr<TensorPipeAgent>(
                    new TensorPipeAgent(
                        store,
                        std::move(selfName),
                        selfId,
                        worldSize,
                        std::move(opts),
                        std::move(reverseDeviceMaps),
                        std::move(devices),
                        std::make_unique<RequestCallbackImpl>()),
                    impl::destroy_without_gil<TensorPipeAgent>);
              }),
          py::arg("store"),
          py::arg("name"),
          py::arg("rank"),
          py::arg("world_size"),
          py::arg("rpc_backend_options"),
          py::arg("reverse_device_maps"),
          py::arg("devices"))
      .def(
          "join",
          &TensorPipeAgent::join,
          py::call_guard<py::gil_scoped_release>(),
          py::arg("shutdown") = false,
          py::arg("timeout") = 0)
      .def(
          "shutdown",
          &TensorPipeAgent::shutdown,
          py::call_guard<py::gil_scoped_release>())
      .def(
          "get_worker_info",
          (const WorkerInfo& (TensorPipeAgent::*)(void) const) &
              RpcAgent::getWorkerInfo,
          py::call_guard<py::gil_scoped_release>())
      .def(
          "get_worker_info",
          (const WorkerInfo& (TensorPipeAgent::*)(const std::string&) const) &
              TensorPipeAgent::getWorkerInfo,
          py::call_guard<py::gil_scoped_release>())
      .def(
          "get_worker_info",
          (const WorkerInfo& (TensorPipeAgent::*)(worker_id_t id) const) &
              TensorPipeAgent::getWorkerInfo,
          py::call_guard<py::gil_scoped_release>())
      .def(
          "get_worker_infos",
          (std::vector<WorkerInfo>(TensorPipeAgent::*)() const) &
              TensorPipeAgent::getWorkerInfos,
          py::call_guard<py::gil_scoped_release>())
      .def(
          "_get_device_map",
          (DeviceMap(TensorPipeAgent::*)(const WorkerInfo& dst) const) &
              TensorPipeAgent::getDeviceMap,
          py::call_guard<py::gil_scoped_release>())
      .def(
          "_get_backend_options",
          &TensorPipeAgent::getBackendOptions,
          py::call_guard<py::gil_scoped_release>())
      .def(
          "_update_group_membership",
          &TensorPipeAgent::updateGroupMembership,
          py::call_guard<py::gil_scoped_release>())
      .def_readonly("is_static_group", &TensorPipeAgent::isStaticGroup_)
      .def_property_readonly("store", &TensorPipeAgent::getStore);

#endif // USE_TENSORPIPE

  // ---------------------------------------------------------------------
  // Current-agent management.
  // ---------------------------------------------------------------------
  module.def("_is_current_rpc_agent_set", &RpcAgent::isCurrentRpcAgentSet);

  module.def("_get_current_rpc_agent", &RpcAgent::getCurrentRpcAgent);

  module.def(
      "_set_and_start_rpc_agent",
      [](const std::shared_ptr<RpcAgent>& rpcAgent) {
        RpcAgent::setCurrentRpcAgent(rpcAgent);
        // Initializing typeResolver inside RpcAgent constructor will make
        // RpcAgent have python dependency. To avoid RpcAgent to have python
        // dependency, setTypeResolver() here.
        //
        // The resolver is handed to the agent and may be invoked after this
        // call returns, so it captures nothing: it only uses the
        // PythonRpcHandler singleton and its own argument.
        std::shared_ptr<TypeResolver> typeResolver =
            std::make_shared<TypeResolver>([](const c10::QualifiedName& qn) {
              auto typePtr = PythonRpcHandler::getInstance().parseTypeFromStr(
                  qn.qualifiedName());
              return c10::StrongTypePtr(
                  PythonRpcHandler::getInstance().jitCompilationUnit(),
                  std::move(typePtr));
            });
        rpcAgent->setTypeResolver(typeResolver);
        rpcAgent->start();
      },
      py::call_guard<py::gil_scoped_release>());

  module.def(
      "_reset_current_rpc_agent",
      []() { RpcAgent::setCurrentRpcAgent(nullptr); },
      py::call_guard<py::gil_scoped_release>());

  // ---------------------------------------------------------------------
  // RRef context management.
  // ---------------------------------------------------------------------
  module.def(
      "_delete_all_user_and_unforked_owner_rrefs",
      [](std::chrono::milliseconds timeoutMillis) {
        RRefContext::getInstance().delAllUsersAndUnforkedOwners(timeoutMillis);
      },
      py::arg("timeout") = kDeleteAllUsersTimeout,
      py::call_guard<py::gil_scoped_release>());

  module.def("_destroy_rref_context", [](bool ignoreRRefLeak) {
    // NB: do not release GIL in the function. The destroyInstance() method
    // returns a list of deleted OwnerRRefs that hold py::object instances.
    // Clearing those OwnerRRefs are likely to trigger Python deref, which
    // requires GIL.
    RRefContext::getInstance().destroyInstance(ignoreRRefLeak).clear();
  });

  module.def("_rref_context_get_debug_info", []() {
    return RRefContext::getInstance().getDebugInfo();
  });

  module.def(
      "_cleanup_python_rpc_handler",
      []() { PythonRpcHandler::getInstance().cleanup(); },
      py::call_guard<py::gil_scoped_release>());

  // ---------------------------------------------------------------------
  // RPC / remote invocation entry points. Each rpc variant wraps the
  // returned future in a jit::PythonFutureWrapper.
  // ---------------------------------------------------------------------
  module.def(
      "_invoke_rpc_builtin",
      [](const WorkerInfo& dst,
         const std::string& opName,
         const float rpcTimeoutSeconds,
         const py::args& args,
         const py::kwargs& kwargs) {
        return std::make_shared<jit::PythonFutureWrapper>(
            pyRpcBuiltin(dst, opName, args, kwargs, rpcTimeoutSeconds));
      },
      py::call_guard<py::gil_scoped_acquire>());

  module.def(
      "_invoke_rpc_python_udf",
      [](const WorkerInfo& dst,
         std::string& pickledPythonUDF,
         std::vector<torch::Tensor>& tensors,
         const float rpcTimeoutSeconds,
         const bool isAsyncExecution) {
        return std::make_shared<jit::PythonFutureWrapper>(pyRpcPythonUdf(
            dst,
            pickledPythonUDF,
            tensors,
            rpcTimeoutSeconds,
            isAsyncExecution));
      },
      py::call_guard<py::gil_scoped_release>());

  module.def(
      "_invoke_rpc_torchscript",
      [](const std::string& dstWorkerName,
         const std::string& qualifiedNameStr,
         const py::tuple& argsTuple,
         const py::dict& kwargsDict,
         const float rpcTimeoutSeconds,
         const bool isAsyncExecution) {
        return std::make_shared<jit::PythonFutureWrapper>(pyRpcTorchscript(
            dstWorkerName,
            qualifiedNameStr,
            argsTuple,
            kwargsDict,
            rpcTimeoutSeconds,
            isAsyncExecution));
      },
      py::call_guard<py::gil_scoped_release>());

  module.def(
      "_invoke_remote_builtin",
      &pyRemoteBuiltin,
      py::call_guard<py::gil_scoped_acquire>());

  module.def(
      "_invoke_remote_python_udf",
      &pyRemotePythonUdf,
      py::call_guard<py::gil_scoped_release>());

  module.def(
      "_invoke_remote_torchscript",
      &pyRemoteTorchscript,
      py::call_guard<py::gil_scoped_release>());

  // ---------------------------------------------------------------------
  // Timeouts and profiling.
  // ---------------------------------------------------------------------
  module.def(
      "get_rpc_timeout",
      []() {
        return RpcAgent::getCurrentRpcAgent()->getRpcTimeout().count() /
            kSecToMsConversion;
      },
      R"(
Retrieve the default timeout for all RPCs that was set during RPC initialization.
The returned value will be in seconds.
Returns:
``float`` indicating the RPC timeout in seconds.
)");

  module.def(
      "enable_gil_profiling",
      [](bool flag) {
        RpcAgent::getCurrentRpcAgent()->enableGILProfiling(flag);
      },
      R"(
Set whether GIL wait times should be enabled or not. This incurs a slight
overhead cost. Default is disabled for performance reasons.
Args:
flag (bool): True to set GIL profiling, False to disable.
)");

  module.def(
      "_set_rpc_timeout",
      [](const float rpcTimeoutSeconds) {
        // Convert using the duration's own representation type (a signed
        // integer of at least 45 bits) instead of `int`. Converting a
        // floating-point value that does not fit the destination integer
        // type is undefined behavior. With `int`, and a seconds-to-ms factor
        // of 1000, that was already reached for timeouts above roughly
        // 2147483 seconds (~24.8 days).
        auto rpcTimeout = std::chrono::milliseconds(
            static_cast<std::chrono::milliseconds::rep>(
                rpcTimeoutSeconds * kSecToMsConversion));
        RpcAgent::getCurrentRpcAgent()->setRpcTimeout(rpcTimeout);
      },
      R"(
Set the default timeout for all RPCs. The input unit is expected to be
in seconds. If an RPC is not completed within this time, an exception
indicating it has timed out will be raised. To control timeout for
specific RPCs, a timeout parameter can be passed into
:meth:`~torch.distributed.rpc.rpc_sync` and
:meth:`~torch.distributed.rpc.rpc_async`.
Args:
rpcTimeoutSeconds (float): Timeout value in seconds.
)");

  module.def(
      "_enable_server_process_global_profiler",
      &profiler::processglobal::enableServer);
  module.def(
      "_disable_server_process_global_profiler",
      &profiler::processglobal::disableServer);

  module.def("_set_profiler_node_id", &at::RecordFunction::setDefaultNodeId);

  // RemoteProfilerManager is held with py::nodelete, so Python never
  // destroys the C++ object.
  py::class_<
      RemoteProfilerManager,
      std::unique_ptr<RemoteProfilerManager, py::nodelete>>(
      module, "RemoteProfilerManager")
      .def("set_current_profiling_key", [](const std::string& key) {
        auto& inst = RemoteProfilerManager::getInstance();
        inst.setCurrentKey(key);
      });

  // ---------------------------------------------------------------------
  // JIT RRef pickling toggles.
  // ---------------------------------------------------------------------
  module.def(
      "_enable_jit_rref_pickle",
      &enableJitRRefPickle,
      R"(
Allows ``torch.jit.save`` to save a ``torch.jit.ScriptModule`` with
pickled RRefs out of RPC contexts.
.. warning::
This is dangerous. If the module contains RRefs, the pickled
result must be sent over RPC and get unpickled on the receiving side
to restore the module. Otherwise, there will be RRef leaks, which
can potentially lead to program hang. When using this API, it is
applications responsibility to make sure that the above assumption
always holds.
)");
  module.def("_disable_jit_rref_pickle", &disableJitRRefPickle);

  Py_RETURN_TRUE;
}
} // namespace
// CPython method table for this translation unit. Its single real entry,
// `_rpc_init`, is bound to rpc_init, takes no arguments (METH_NOARGS) and
// has no docstring. The trailing all-null entry terminates the table.
static PyMethodDef methods[] = { // NOLINT
    {"_rpc_init", rpc_init, METH_NOARGS, nullptr},
    {nullptr, nullptr, 0, nullptr}};

// Returns a pointer to the first entry of `methods`. The table has static
// storage duration, so the pointer remains valid for the whole process.
PyMethodDef* python_functions() {
  PyMethodDef* const table = &methods[0];
  return table;
}
} // namespace rpc
} // namespace distributed
} // namespace torch