-
Notifications
You must be signed in to change notification settings - Fork 27
/
Copy path__init__.py
100 lines (88 loc) · 5.1 KB
/
__init__.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import asyncio, warnings, copy, time
class BaseNode:
    """Smallest unit of work, executed as ``prep`` -> ``_exec`` -> ``post``.

    Attributes:
        params: dict assigned through ``set_params`` (starts empty).
        successors: maps an action name to the node that follows it.
    """

    def __init__(self):
        self.params = {}
        self.successors = {}

    def set_params(self, params):
        """Bind ``params`` to this node as-is (the object is not copied)."""
        self.params = params

    def add_successor(self, node, action="default"):
        """Route ``action`` to ``node`` and return ``node`` (enables chaining).

        Emits a warning when an existing route for ``action`` is replaced.
        """
        replacing = action in self.successors
        if replacing:
            warnings.warn(f"Overwriting successor for action '{action}'")
        self.successors[action] = node
        return node

    # Lifecycle hooks meant to be overridden; the defaults do nothing.
    def prep(self, shared):
        return None

    def exec(self, prep_res):
        return None

    def post(self, shared, prep_res, exec_res):
        return None

    def _exec(self, prep_res):
        """Indirection around ``exec`` so subclasses can wrap it (Node adds retries)."""
        return self.exec(prep_res)

    def _run(self, shared):
        """Run the three hooks in order and return what ``post`` returns."""
        prepared = self.prep(shared)
        executed = self._exec(prepared)
        return self.post(shared, prepared, executed)

    def run(self, shared):
        """Run only this node; successors are ignored (a warning says so)."""
        if self.successors:
            warnings.warn("Node won't run successors. Use Flow.")
        return self._run(shared)

    def __rshift__(self, other):
        # ``a >> b`` wires ``b`` as the "default" successor of ``a``.
        return self.add_successor(other)

    def __sub__(self, action):
        # ``a - "name" >> b`` wires ``b`` under a named action.
        if not isinstance(action, str):
            raise TypeError("Action must be a string")
        return _ConditionalTransition(self, action)
class _ConditionalTransition:
    """Intermediate object produced by ``node - "action"``.

    Completing the expression with ``>> target`` registers ``target`` as the
    successor of ``src`` for ``action`` and returns ``target``.
    """

    def __init__(self, src, action):
        self.src = src
        self.action = action

    def __rshift__(self, tgt):
        source, label = self.src, self.action
        return source.add_successor(tgt, label)
class Node(BaseNode):
    """BaseNode whose ``exec`` is retried on exception.

    Args:
        max_retries: total number of ``exec`` attempts (default 1, i.e. no
            retry). Values below 1 are treated as 1 so ``exec`` always runs.
        wait: seconds to sleep between failed attempts (0 disables sleeping).
    """

    def __init__(self, max_retries=1, wait=0):
        super().__init__()
        self.max_retries, self.wait = max_retries, wait

    def exec_fallback(self, prep_res, exc):
        """Called with the last exception once all attempts failed; re-raises by default."""
        raise exc

    def _exec(self, prep_res):
        """Call ``exec`` up to ``max_retries`` times, then defer to ``exec_fallback``.

        ``self.cur_retry`` holds the 0-based attempt index while ``exec`` runs,
        so ``exec`` can inspect it.
        """
        # Clamp so max_retries <= 0 doesn't make range() empty, which would
        # skip exec entirely and silently hand None to post().
        attempts = max(1, self.max_retries)
        for self.cur_retry in range(attempts):
            try:
                return self.exec(prep_res)
            except Exception as e:
                if self.cur_retry == attempts - 1:
                    return self.exec_fallback(prep_res, e)
                if self.wait > 0:
                    time.sleep(self.wait)
class BatchNode(Node):
    """Node that applies ``exec`` to every item returned by ``prep``.

    Each item goes through the parent's ``_exec`` separately, so retries and
    fallback apply per item. A falsy ``prep`` result (e.g. ``None``) is an
    empty batch. Returns the per-item results as a list, in item order.
    """

    def _exec(self, items):
        outputs = []
        for item in items or []:
            outputs.append(super()._exec(item))
        return outputs
class Flow(BaseNode):
    """Walks a graph of nodes beginning at ``start``.

    Each node's ``_run`` return value is the action used to pick the next
    node from its ``successors``; the walk stops when no successor matches.
    Nodes are shallow-copied before running, so ``set_params`` on the copy
    does not rebind params on the graph's own node objects.
    """

    def __init__(self, start):
        super().__init__()
        self.start = start

    def get_next_node(self, curr, action):
        """Return the successor of ``curr`` for ``action`` (falsy -> "default"), or None."""
        key = action if action else "default"
        successor = curr.successors.get(key)
        if curr.successors and not successor:
            warnings.warn(f"Flow ends: '{action}' not found in {list(curr.successors)}")
        return successor

    def _orch(self, shared, params=None):
        """Run nodes from ``start`` until the graph ends.

        Every node in this walk receives the same params dict: ``params`` if
        truthy, otherwise a copy of this flow's own params.
        """
        node_params = params or {**self.params}
        node = copy.copy(self.start)
        while node:
            node.set_params(node_params)
            action = node._run(shared)
            node = copy.copy(self.get_next_node(node, action))

    def _run(self, shared):
        prepared = self.prep(shared)
        self._orch(shared)
        # A flow has no exec step of its own, so post always sees exec_res=None.
        return self.post(shared, prepared, None)

    def exec(self, prep_res):
        raise RuntimeError("Flow can't exec.")
class BatchFlow(Flow):
    """Flow that is walked once per param dict returned by ``prep``.

    For each dict, the flow's own params are overlaid with that dict's
    entries before the walk. A falsy ``prep`` result means no walks; ``post``
    then receives an empty list.
    """

    def _run(self, shared):
        batch = self.prep(shared)
        if not batch:
            batch = []
        for overrides in batch:
            self._orch(shared, {**self.params, **overrides})
        return self.post(shared, batch, None)
class AsyncNode(Node):
    """Node whose lifecycle hooks are coroutines.

    The synchronous hooks inherited from Node are blocked and raise
    RuntimeError naming their ``*_async`` replacement; run the node with
    ``await node.run_async(shared)``.
    """

    # Sync entry points deliberately fail so an async node is never driven
    # through the synchronous path by mistake.
    def prep(self, shared):
        raise RuntimeError("Use prep_async.")

    def exec(self, prep_res):
        raise RuntimeError("Use exec_async.")

    def post(self, shared, prep_res, exec_res):
        raise RuntimeError("Use post_async.")

    def exec_fallback(self, prep_res, exc):
        raise RuntimeError("Use exec_fallback_async.")

    def _run(self, shared):
        raise RuntimeError("Use run_async.")

    async def prep_async(self, shared):
        pass

    async def exec_async(self, prep_res):
        pass

    async def exec_fallback_async(self, prep_res, exc):
        """Called with the last exception after all attempts failed; re-raises by default."""
        raise exc

    async def post_async(self, shared, prep_res, exec_res):
        pass

    async def _exec(self, prep_res):
        """Await ``exec_async`` with the same retry policy as ``Node._exec``.

        At least one attempt is always made (max_retries < 1 is treated as 1).
        ``self.cur_retry`` is set to the 0-based attempt index, as in Node.
        The retry decision itself uses a local counter: AsyncParallelBatchNode
        runs several items on this same instance concurrently, so the
        attribute may be overwritten by another item while this one awaits.
        """
        attempts = max(1, self.max_retries)
        for attempt in range(attempts):
            self.cur_retry = attempt
            try:
                return await self.exec_async(prep_res)
            except Exception as e:
                if attempt == attempts - 1:
                    return await self.exec_fallback_async(prep_res, e)
                if self.wait > 0:
                    await asyncio.sleep(self.wait)

    async def run_async(self, shared):
        if self.successors:
            warnings.warn("Node won't run successors. Use AsyncFlow.")
        return await self._run_async(shared)

    async def _run_async(self, shared):
        p = await self.prep_async(shared)
        e = await self._exec(p)
        return await self.post_async(shared, p, e)
class AsyncBatchNode(AsyncNode, BatchNode):
    """Async batch node: awaits the retrying ``_exec`` once per item, in order.

    ``items`` is whatever ``prep_async`` returned. A falsy value (including
    the default ``None``) is treated as an empty batch, matching BatchNode;
    iterating it directly previously raised TypeError.
    """

    async def _exec(self, items):
        # Next in the MRO after this class is AsyncNode, so each item gets
        # AsyncNode._exec (per-item retries and fallback).
        return [await super(AsyncBatchNode, self)._exec(i) for i in (items or [])]
class AsyncParallelBatchNode(AsyncNode, BatchNode):
    """Async batch node that runs all items concurrently via ``asyncio.gather``.

    Results come back as a list in item order. A falsy ``items`` (including
    the default ``None`` from ``prep_async``) is an empty batch, matching
    BatchNode, instead of raising TypeError.
    """

    async def _exec(self, items):
        # Two-argument super() is required: zero-arg super() does not work
        # inside a generator expression.
        return await asyncio.gather(
            *(super(AsyncParallelBatchNode, self)._exec(i) for i in (items or []))
        )
class AsyncFlow(Flow, AsyncNode):
    """Flow driven with ``run_async``; may mix async and sync nodes.

    AsyncNode instances are awaited via ``_run_async``. Other nodes are
    called inline through their synchronous ``_run`` (not awaited).
    """

    async def _orch_async(self, shared, params=None):
        """Async counterpart of ``Flow._orch``: walk the graph from ``start``."""
        node_params = params or {**self.params}
        node = copy.copy(self.start)
        while node:
            node.set_params(node_params)
            if isinstance(node, AsyncNode):
                action = await node._run_async(shared)
            else:
                action = node._run(shared)
            node = copy.copy(self.get_next_node(node, action))

    async def _run_async(self, shared):
        prepared = await self.prep_async(shared)
        await self._orch_async(shared)
        return await self.post_async(shared, prepared, None)
class AsyncBatchFlow(AsyncFlow, BatchFlow):
    """Async BatchFlow: walks the flow once per param dict, one after another.

    A falsy ``prep_async`` result means no walks; ``post_async`` then receives
    an empty list.
    """

    async def _run_async(self, shared):
        batch = await self.prep_async(shared)
        if not batch:
            batch = []
        for overrides in batch:
            await self._orch_async(shared, {**self.params, **overrides})
        return await self.post_async(shared, batch, None)
class AsyncParallelBatchFlow(AsyncFlow, BatchFlow):
    """Async BatchFlow whose per-param-dict walks run concurrently.

    All walks share the same ``shared`` object. A falsy ``prep_async``
    result means no walks; ``post_async`` then receives an empty list.
    """

    async def _run_async(self, shared):
        batch = await self.prep_async(shared)
        if not batch:
            batch = []
        walks = [self._orch_async(shared, {**self.params, **overrides}) for overrides in batch]
        await asyncio.gather(*walks)
        return await self.post_async(shared, batch, None)