-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathadapter.py
More file actions
346 lines (283 loc) · 11.2 KB
/
adapter.py
File metadata and controls
346 lines (283 loc) · 11.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
from contextlib import contextmanager
import sqlalchemy
from casbin import persist
from sqlalchemy import Column, Integer, String
from sqlalchemy import create_engine, or_
from sqlalchemy.orm import sessionmaker
# Declarative base class shared by every generated CasbinRule mapping.
# Releases whose version string does not start with "1." use the class-based
# ``DeclarativeBase`` API; 1.x releases use the ``declarative_base()`` factory.
if not sqlalchemy.__version__.startswith("1."):
    from sqlalchemy.orm import DeclarativeBase

    class Base(DeclarativeBase):
        pass

else:
    from sqlalchemy.orm import declarative_base

    Base = declarative_base()
# Mapped classes already produced by create_casbin_rule_class, keyed by table
# name. Reusing them avoids declaring the same mapping twice (duplicate class
# warnings from SQLAlchemy).
_casbin_rule_cache = {}


def create_casbin_rule_class(table_name):
    """
    Build, or fetch from the cache, a CasbinRule mapped class for a table.

    Args:
        table_name (str): Name of the database table the class maps to.

    Returns:
        db_class (CasbinRule): Declarative class with an integer ``id``
        primary key plus ``ptype`` and ``v0``..``v5`` string columns.
        Calling again with the same ``table_name`` returns the same class.
    """
    cached = _casbin_rule_cache.get(table_name)
    if cached is not None:
        return cached

    # Each table gets its own class; non-alphanumeric characters are replaced
    # so the generated name is a valid Python identifier.
    safe_suffix = "".join(ch if ch.isalnum() else "_" for ch in table_name)

    def _rule_str(self):
        # "ptype, v0, v1, ..." with unset (None) columns omitted.
        fields = (self.v0, self.v1, self.v2, self.v3, self.v4, self.v5)
        return ", ".join([self.ptype, *(f for f in fields if f is not None)])

    def _rule_repr(self):
        return '<CasbinRule {}: "{}">'.format(self.id, str(self))

    # Columns are created in the same order as the table layout:
    # id, ptype, v0..v5.
    namespace = {
        "__tablename__": table_name,
        "__table_args__": {"extend_existing": True},
        "id": Column(Integer, primary_key=True),
        "ptype": Column(String(255)),
    }
    for index in range(6):
        namespace[f"v{index}"] = Column(String(255))
    namespace["__str__"] = _rule_str
    namespace["__repr__"] = _rule_repr
    namespace["__module__"] = "sqlalchemy_adapter.adapter"

    rule_class = type("CasbinRule_" + safe_suffix, (Base,), namespace)

    _casbin_rule_cache[table_name] = rule_class
    return rule_class
# Module-level default CasbinRule class, mapped to the ``casbin_rule`` table.
CasbinRule = create_casbin_rule_class(table_name="casbin_rule")
class Filter:
    """Column constraints used when loading or updating a filtered policy.

    Each attribute holds the accepted values for that column. An empty list
    means the column is not constrained.
    """

    # Class-level defaults are kept for backward compatibility. Instances get
    # their own lists in ``__init__``, so mutating one filter in place
    # (e.g. ``f.v0.append(...)``) no longer leaks into every other Filter.
    ptype = []
    v0 = []
    v1 = []
    v2 = []
    v3 = []
    v4 = []
    v5 = []

    def __init__(self):
        self.ptype = []
        self.v0 = []
        self.v1 = []
        self.v2 = []
        self.v3 = []
        self.v4 = []
        self.v5 = []
class Adapter(persist.Adapter, persist.adapters.UpdateAdapter):
    """SQLAlchemy-backed Casbin adapter.

    Stores each policy rule as one row of ``db_class``. The ``ptype`` column
    holds the policy type and ``v0``..``v5`` hold the rule values. Implements
    both ``persist.Adapter`` and ``persist.adapters.UpdateAdapter``.
    """

    def __init__(
        self,
        engine,
        db_class=None,
        table_name="casbin_rule",
        filtered=False,
        create_table=True,
    ):
        """
        Args:
            engine: A SQLAlchemy engine, or a database URL string that is
                passed to ``create_engine``.
            db_class: Optional custom mapped class. It must expose ``id``,
                ``ptype`` and ``v0``..``v5``.
            table_name: Table name for the generated class. Used only when
                ``db_class`` is not given.
            filtered: Initial value reported by ``is_filtered``.
            create_table: When true, run ``metadata.create_all`` on the engine.

        Raises:
            Exception: if ``db_class`` lacks one of the required attributes.
        """
        self._engine = create_engine(engine) if isinstance(engine, str) else engine

        if db_class is None:
            db_class = create_casbin_rule_class(table_name=table_name)
            metadata = Base.metadata
        else:
            # ``id`` is required because filtered queries are ordered by it.
            for attr in ("id", "ptype", "v0", "v1", "v2", "v3", "v4", "v5"):
                if not hasattr(db_class, attr):
                    raise Exception(f"{attr} not found in custom DatabaseClass.")
            metadata = db_class.metadata

        self._db_class = db_class
        self.session_local = sessionmaker(bind=self._engine)

        if create_table:
            metadata.create_all(self._engine)

        self._filtered = filtered

    @contextmanager
    def _session_scope(self):
        """Yield a session that commits on success and rolls back on error.

        The session is always closed afterwards.
        """
        session = self.session_local()
        try:
            yield session
            session.commit()
        except Exception:
            session.rollback()
            raise
        finally:
            session.close()

    def load_policy(self, model):
        """Load every stored policy rule into ``model``."""
        with self._session_scope() as session:
            for line in session.query(self._db_class).all():
                persist.load_policy_line(str(line), model)

    def is_filtered(self):
        """Return whether the adapter is in filtered mode.

        This reflects the constructor flag, or a previous call to
        ``load_filtered_policy``.
        """
        return self._filtered

    def load_filtered_policy(self, model, filter) -> None:
        """Load only the rules matching ``filter`` into ``model``.

        Also switches the adapter into filtered mode.
        """
        with self._session_scope() as session:
            query = self.filter_query(session.query(self._db_class), filter)
            for line in query.all():
                persist.load_policy_line(str(line), model)
        self._filtered = True

    def filter_query(self, querydb, filter):
        """Restrict ``querydb`` to rows matching ``filter``, ordered by ``id``.

        Each of ``ptype`` and ``v0``..``v5`` on ``filter`` is a collection of
        accepted values. A non-empty string is treated as a one-element list.
        An empty value, or a missing attribute, leaves that column
        unconstrained.
        """
        for attr in ("ptype", "v0", "v1", "v2", "v3", "v4", "v5"):
            values = getattr(filter, attr, None)
            if isinstance(values, str):
                # ``in_`` rejects a bare string, so wrap it.
                # "" keeps meaning "no constraint".
                values = [values] if values else []
            if values:
                querydb = querydb.filter(getattr(self._db_class, attr).in_(values))
        return querydb.order_by(self._db_class.id)

    def _save_policy_line(self, ptype, rule, session=None):
        """Insert one rule.

        Uses ``session`` if given; otherwise opens its own transaction.
        """
        line = self._db_class(ptype=ptype)
        for index, value in enumerate(rule):
            setattr(line, f"v{index}", value)
        if session:
            session.add(line)
        else:
            with self._session_scope() as own_session:
                own_session.add(line)

    def save_policy(self, model):
        """Replace all stored rules with the p/g rules currently in ``model``.

        Runs as a single transaction. Returns True.
        """
        with self._session_scope() as session:
            session.query(self._db_class).delete()
            for sec in ("p", "g"):
                if sec not in model.model.keys():
                    continue
                for ptype, ast in model.model[sec].items():
                    for rule in ast.policy:
                        self._save_policy_line(ptype, rule, session=session)
        return True

    def add_policy(self, sec, ptype, rule):
        """Add one policy rule to the storage."""
        self._save_policy_line(ptype, rule)

    def add_policies(self, sec, ptype, rules):
        """Add several policy rules in a single transaction.

        If any insert fails, none of the rules are stored.
        """
        with self._session_scope() as session:
            for rule in rules:
                self._save_policy_line(ptype, rule, session=session)

    def remove_policy(self, sec, ptype, rule):
        """Remove rows matching ``ptype`` and ``rule``.

        ``rule`` is matched position by position against v0, v1, ...
        Returns True if at least one row was deleted.
        """
        with self._session_scope() as session:
            query = session.query(self._db_class).filter(self._db_class.ptype == ptype)
            for index, value in enumerate(rule):
                query = query.filter(getattr(self._db_class, f"v{index}") == value)
            deleted = query.delete()
        return deleted > 0

    def remove_policies(self, sec, ptype, rules):
        """Remove several policy rules of ``ptype`` in one statement.

        A row is deleted only if it matches one of ``rules`` on every value
        that rule provides. Values from different rules are never combined:
        removing ``[a, x]`` and ``[b, y]`` leaves ``[a, y]`` untouched.
        """
        rules = [list(rule) for rule in rules]
        if not rules:
            return
        with self._session_scope() as session:
            per_rule = [
                # ``true()`` keeps an empty rule valid: it matches any row of ``ptype``.
                sqlalchemy.and_(
                    sqlalchemy.true(),
                    *(
                        getattr(self._db_class, f"v{index}") == value
                        for index, value in enumerate(rule)
                    ),
                )
                for rule in rules
            ]
            # The session was just opened and holds no loaded objects,
            # so there is nothing in memory to synchronise.
            (
                session.query(self._db_class)
                .filter(self._db_class.ptype == ptype)
                .filter(or_(*per_rule))
                .delete(synchronize_session=False)
            )

    def remove_filtered_policy(self, sec, ptype, field_index, *field_values):
        """Remove rules that match the filter; part of the Auto-Save feature.

        Field ``field_values[i]`` is compared against column
        ``v{field_index + i}``. An empty string matches any value.

        Returns:
            False if the arguments address columns outside v0..v5, otherwise
            whether at least one row was deleted.
        """
        if not (0 <= field_index <= 5):
            return False
        if not (1 <= field_index + len(field_values) <= 6):
            return False

        with self._session_scope() as session:
            query = session.query(self._db_class).filter(self._db_class.ptype == ptype)
            for offset, value in enumerate(field_values):
                if value != "":
                    column = getattr(self._db_class, f"v{field_index + offset}")
                    query = query.filter(column == value)
            deleted = query.delete()
        return deleted > 0

    def _update_policy_line(self, session, ptype, old_rule, new_rule):
        """Overwrite the single row equal to ``old_rule`` with ``new_rule`` in ``session``.

        Columns beyond the end of ``new_rule`` (up to ``len(old_rule)``) are
        cleared to None.

        Raises:
            whatever ``Query.one()`` raises if zero or several rows match.
        """
        query = session.query(self._db_class).filter(self._db_class.ptype == ptype)
        for index, value in enumerate(old_rule):
            query = query.filter(getattr(self._db_class, f"v{index}") == value)
        old_rule_line = query.one()

        for index in range(max(len(old_rule), len(new_rule))):
            value = new_rule[index] if index < len(new_rule) else None
            setattr(old_rule_line, f"v{index}", value)

    def update_policy(
        self, sec: str, ptype: str, old_rule: [str], new_rule: [str]
    ) -> None:
        """
        Update the old_rule with the new_rule in the database (storage).

        :param sec: section type
        :param ptype: policy type
        :param old_rule: the old rule that needs to be modified
        :param new_rule: the new rule to replace the old rule
        :return: None
        """
        with self._session_scope() as session:
            self._update_policy_line(session, ptype, old_rule, new_rule)

    def update_policies(
        self,
        sec: str,
        ptype: str,
        old_rules: [
            [str],
        ],
        new_rules: [
            [str],
        ],
    ) -> None:
        """
        Update the old_rules with the new_rules in the database (storage).

        All updates run in one transaction: if any old rule is not found,
        nothing is changed.

        :param sec: section type
        :param ptype: policy type
        :param old_rules: the old rules that need to be modified
        :param new_rules: the new rules to replace the old rules, paired by position
        :raises ValueError: if the two lists differ in length
        :return: None
        """
        if len(old_rules) != len(new_rules):
            raise ValueError("old_rules and new_rules must have the same length")
        with self._session_scope() as session:
            for old_rule, new_rule in zip(old_rules, new_rules):
                self._update_policy_line(session, ptype, old_rule, new_rule)

    def update_filtered_policies(
        self, sec, ptype, new_rules: [[str]], field_index, *field_values
    ) -> [[str]]:
        """Replace all rules of ``ptype`` that match the filter with ``new_rules``.

        ``field_values[i]`` constrains column ``v{field_index + i}``. An empty
        string leaves that column unconstrained, as in
        ``remove_filtered_policy``.

        Returns:
            The removed rules, as lists of strings.

        Raises:
            ValueError: if the fields fall outside v0..v5. Silently dropping
                a constraint would delete too much.
        """
        if not (0 <= field_index <= 5) or field_index + len(field_values) > 6:
            raise ValueError("field_index/field_values must address columns v0..v5")

        policy_filter = Filter()
        policy_filter.ptype = ptype
        for offset, value in enumerate(field_values):
            if value != "":
                setattr(policy_filter, f"v{field_index + offset}", [value])

        return self._update_filtered_policies(new_rules, policy_filter)

    def _update_filtered_policies(self, new_rules, filter) -> [[str]]:
        """Delete the rules matched by ``filter`` and insert ``new_rules`` in one transaction.

        Returns:
            The deleted rules as lists of strings, with None columns omitted.
        """
        with self._session_scope() as session:
            query = session.query(self._db_class).filter(
                self._db_class.ptype == filter.ptype
            )
            old_lines = self.filter_query(query, filter).all()

            # Copy the values out now. The rows are expired on commit and
            # detached once the session closes, so they can't be read later.
            old_rules = [
                [
                    value
                    for value in (line.v0, line.v1, line.v2, line.v3, line.v4, line.v5)
                    if value is not None
                ]
                for line in old_lines
            ]

            for line in old_lines:
                session.delete(line)
            for rule in new_rules:
                self._save_policy_line(filter.ptype, rule, session=session)

        return old_rules