From 940aeec0a2c039fdf599bfe05f1c12e157cd1753 Mon Sep 17 00:00:00 2001 From: Dobando <1692898084@qq.com> Date: Sat, 3 Aug 2024 21:55:59 +0800 Subject: [PATCH] feat: Implemented variable order field index --- casbin/async_enforcer.py | 11 ++-- casbin/async_internal_enforcer.py | 9 --- casbin/async_management_enforcer.py | 25 +++++-- casbin/enforcer.py | 22 ++++--- casbin/internal_enforcer.py | 9 --- casbin/management_enforcer.py | 15 ++++- casbin/model/assertion.py | 1 - casbin/model/model.py | 42 +++--------- casbin/model/policy.py | 65 ++++++++++++++++++- .../priority_model_explicit_customized.conf | 14 ++++ .../priority_policy_explicit_customized.csv | 12 ++++ tests/test_rbac_api.py | 36 ++++++++-- 12 files changed, 176 insertions(+), 85 deletions(-) create mode 100644 examples/priority_model_explicit_customized.conf create mode 100644 examples/priority_policy_explicit_customized.csv diff --git a/casbin/async_enforcer.py b/casbin/async_enforcer.py index f61322f7..9d7eedb3 100644 --- a/casbin/async_enforcer.py +++ b/casbin/async_enforcer.py @@ -16,6 +16,7 @@ from casbin.async_management_enforcer import AsyncManagementEnforcer from casbin.util import join_slice, array_remove_duplicates, set_subtract +from casbin.constant.constants import DOMAIN_INDEX, SUBJECT_INDEX, OBJECT_INDEX class AsyncEnforcer(AsyncManagementEnforcer): @@ -280,8 +281,8 @@ async def get_implicit_users_for_resource(self, resource): get_implicit_users_for_resource("data1") will return [[alice data1 read]] Note: only users will be returned, roles (2nd arg in "g") will be excluded.""" permissions = dict() - subject_index = await self.get_field_index("p", "sub") - object_index = await self.get_field_index("p", "obj") + subject_index = await self.get_field_index("p", SUBJECT_INDEX) + object_index = await self.get_field_index("p", OBJECT_INDEX) rm = self.get_role_manager() roles = self.get_all_roles() @@ -304,9 +305,9 @@ async def get_implicit_users_for_resource_by_domain(self, resource, domain): """get implicit user based on resource and domain. Compared to GetImplicitUsersForResource, domain is supported""" permissions = dict() - subject_index = await self.get_field_index("p", "sub") - object_index = await self.get_field_index("p", "obj") - dom_index = await self.get_field_index("p", "dom") + subject_index = await self.get_field_index("p", SUBJECT_INDEX) + object_index = await self.get_field_index("p", OBJECT_INDEX) + dom_index = await self.get_field_index("p", DOMAIN_INDEX) rm = self.get_role_manager() roles = await self.get_all_roles_by_domain(domain) diff --git a/casbin/async_internal_enforcer.py b/casbin/async_internal_enforcer.py index c1d0547d..cd8723ae 100644 --- a/casbin/async_internal_enforcer.py +++ b/casbin/async_internal_enforcer.py @@ -312,12 +312,3 @@ async def _remove_filtered_policy_returns_effects(self, sec, ptype, field_index, self.watcher.update() return rule_removed - - async def get_field_index(self, ptype, field): - """gets the index of the field name.""" - return self.model.get_field_index(ptype, field) - - async def set_field_index(self, ptype, field, index): - """sets the index of the field name.""" - assertion = self.model["p"][ptype] - assertion.field_index_map[field] = index diff --git a/casbin/async_management_enforcer.py b/casbin/async_management_enforcer.py index ef0a5ca8..7a9abd15 100644 --- a/casbin/async_management_enforcer.py +++ b/casbin/async_management_enforcer.py @@ -13,6 +13,7 @@ # limitations under the License. from casbin.async_internal_enforcer import AsyncInternalEnforcer from casbin.model.policy_op import PolicyOp +from casbin.constant.constants import ACTION_INDEX, SUBJECT_INDEX, OBJECT_INDEX class AsyncManagementEnforcer(AsyncInternalEnforcer): @@ -22,27 +23,30 @@ class AsyncManagementEnforcer(AsyncInternalEnforcer): def get_all_subjects(self): """gets the list of subjects that show up in the current policy.""" - return self.get_all_named_subjects("p") + return self.model.get_values_for_field_in_policy_all_types_by_name("p", SUBJECT_INDEX) def get_all_named_subjects(self, ptype): """gets the list of subjects that show up in the current named policy.""" - return self.model.get_values_for_field_in_policy("p", ptype, 0) + field_index = self.model.get_field_index(ptype, SUBJECT_INDEX) + return self.model.get_values_for_field_in_policy("p", ptype, field_index) def get_all_objects(self): """gets the list of objects that show up in the current policy.""" - return self.get_all_named_objects("p") + return self.model.get_values_for_field_in_policy_all_types_by_name("p", OBJECT_INDEX) def get_all_named_objects(self, ptype): """gets the list of objects that show up in the current named policy.""" - return self.model.get_values_for_field_in_policy("p", ptype, 1) + field_index = self.model.get_field_index(ptype, OBJECT_INDEX) + return self.model.get_values_for_field_in_policy("p", ptype, field_index) def get_all_actions(self): """gets the list of actions that show up in the current policy.""" - return self.get_all_named_actions("p") + return self.model.get_values_for_field_in_policy_all_types_by_name("p", ACTION_INDEX) def get_all_named_actions(self, ptype): """gets the list of actions that show up in the current named policy.""" - return self.model.get_values_for_field_in_policy("p", ptype, 2) + field_index = self.model.get_field_index(ptype, ACTION_INDEX) + return self.model.get_values_for_field_in_policy("p", ptype, field_index) def get_all_roles(self): """gets the list of roles that show up in the current named policy.""" @@ -302,3 +306,12 @@ async def remove_filtered_named_grouping_policy(self, ptype, field_index, *field def add_function(self, name, func): """adds a customized function.""" self.fm.add_function(name, func) + + async def get_field_index(self, ptype, field): + """gets the index of the field name.""" + return self.model.get_field_index(ptype, field) + + async def set_field_index(self, ptype, field, index): + """sets the index of the field name.""" + assertion = self.model["p"][ptype] + assertion.field_index_map[field] = index diff --git a/casbin/enforcer.py b/casbin/enforcer.py index a1777f9c..b994426d 100644 --- a/casbin/enforcer.py +++ b/casbin/enforcer.py @@ -16,6 +16,7 @@ from casbin.management_enforcer import ManagementEnforcer from casbin.util import join_slice, array_remove_duplicates, set_subtract +from casbin.constant.constants import DOMAIN_INDEX, SUBJECT_INDEX, OBJECT_INDEX class Enforcer(ManagementEnforcer): @@ -73,7 +74,8 @@ def delete_user(self, user): """ res1 = self.remove_filtered_grouping_policy(0, user) - res2 = self.remove_filtered_policy(0, user) + sub_index = self.get_field_index("p", SUBJECT_INDEX) + res2 = self.remove_filtered_policy(sub_index, user) return res1 or res2 def delete_role(self, role): @@ -83,7 +85,8 @@ def delete_role(self, role): """ res1 = self.remove_filtered_grouping_policy(1, role) - res2 = self.remove_filtered_policy(0, role) + sub_index = self.get_field_index("p", SUBJECT_INDEX) + res2 = self.remove_filtered_policy(sub_index, role) return res1 or res2 def delete_permission(self, *permission): @@ -112,7 +115,10 @@ def delete_permissions_for_user(self, user): deletes permissions for a user or role. Returns false if the user or role does not have any permissions (aka not affected). """ - return self.remove_filtered_policy(0, user) + sub_index = self.get_field_index("p", SUBJECT_INDEX) + if sub_index == -1: + return False + return self.remove_filtered_policy(sub_index, user) def get_permissions_for_user(self, user): """ @@ -289,8 +295,8 @@ def get_implicit_users_for_resource(self, resource): get_implicit_users_for_resource("data1") will return [[alice data1 read]] Note: only users will be returned, roles (2nd arg in "g") will be excluded.""" permissions = dict() - subject_index = self.get_field_index("p", "sub") - object_index = self.get_field_index("p", "obj") + subject_index = self.get_field_index("p", SUBJECT_INDEX) + object_index = self.get_field_index("p", OBJECT_INDEX) rm = self.get_role_manager() roles = self.get_all_roles() @@ -313,9 +319,9 @@ def get_implicit_users_for_resource_by_domain(self, resource, domain): """get implicit user based on resource and domain. Compared to GetImplicitUsersForResource, domain is supported""" permissions = dict() - subject_index = self.get_field_index("p", "sub") - object_index = self.get_field_index("p", "obj") - dom_index = self.get_field_index("p", "dom") + subject_index = self.get_field_index("p", SUBJECT_INDEX) + object_index = self.get_field_index("p", OBJECT_INDEX) + dom_index = self.get_field_index("p", DOMAIN_INDEX) rm = self.get_role_manager() roles = self.get_all_roles_by_domain(domain) diff --git a/casbin/internal_enforcer.py b/casbin/internal_enforcer.py index fe62071e..96c196ed 100644 --- a/casbin/internal_enforcer.py +++ b/casbin/internal_enforcer.py @@ -187,12 +187,3 @@ def _remove_filtered_policy_returns_effects(self, sec, ptype, field_index, *fiel self.watcher.update() return rule_removed - - def get_field_index(self, ptype, field): - """gets the index of the field name.""" - return self.model.get_field_index(ptype, field) - - def set_field_index(self, ptype, field, index): - """sets the index of the field name.""" - assertion = self.model["p"][ptype] - assertion.field_index_map[field] = index diff --git a/casbin/management_enforcer.py b/casbin/management_enforcer.py index b1897fad..955ec984 100644 --- a/casbin/management_enforcer.py +++ b/casbin/management_enforcer.py @@ -24,7 +24,7 @@ class ManagementEnforcer(InternalEnforcer): def get_all_subjects(self): """gets the list of subjects that show up in the current policy.""" - return self.get_all_named_subjects("p") + return self.model.get_values_for_field_in_policy_all_types_by_name("p", SUBJECT_INDEX) def get_all_named_subjects(self, ptype): """gets the list of subjects that show up in the current named policy.""" @@ -33,7 +33,7 @@ def get_all_named_subjects(self, ptype): def get_all_objects(self): """gets the list of objects that show up in the current policy.""" - return self.get_all_named_objects("p") + return self.model.get_values_for_field_in_policy_all_types_by_name("p", OBJECT_INDEX) def get_all_named_objects(self, ptype): """gets the list of objects that show up in the current named policy.""" @@ -42,7 +42,7 @@ def get_all_named_objects(self, ptype): def get_all_actions(self): """gets the list of actions that show up in the current policy.""" - return self.get_all_named_actions("p") + return self.model.get_values_for_field_in_policy_all_types_by_name("p", ACTION_INDEX) def get_all_named_actions(self, ptype): """gets the list of actions that show up in the current named policy.""" @@ -309,3 +309,12 @@ def remove_filtered_named_grouping_policy(self, ptype, field_index, *field_value def add_function(self, name, func): """adds a customized function.""" self.fm.add_function(name, func) + + def get_field_index(self, ptype, field): + """gets the index of the field name.""" + return self.model.get_field_index(ptype, field) + + def set_field_index(self, ptype, field, index): + """sets the index of the field name.""" + assertion = self.model["p"][ptype] + assertion.field_index_map[field] = index diff --git a/casbin/model/assertion.py b/casbin/model/assertion.py index 120ce344..9b6f47a4 100644 --- a/casbin/model/assertion.py +++ b/casbin/model/assertion.py @@ -27,7 +27,6 @@ def __init__(self): self.policy = [] self.rm = None self.cond_rm = None - self.priority_index: int = -1 self.policy_map: dict = {} self.field_index_map: dict = {} diff --git a/casbin/model/model.py b/casbin/model/model.py index ecef7753..e29f1998 100644 --- a/casbin/model/model.py +++ b/casbin/model/model.py @@ -16,6 +16,7 @@ from casbin import util, config from . import Assertion from .policy import Policy +from casbin.constant.constants import DOMAIN_INDEX, PRIORITY_INDEX, SUBJECT_PRIORITY_EFFECT DEFAULT_DOMAIN = "" DEFAULT_SEPARATOR = "::" @@ -116,19 +117,16 @@ def print_model(self): def sort_policies_by_priority(self): for ptype, assertion in self["p"].items(): - for index, token in enumerate(assertion.tokens): - if token == f"{ptype}_priority": - assertion.priority_index = index - break + priority_index = self.get_field_index(ptype, PRIORITY_INDEX) - if assertion.priority_index == -1: + if priority_index == -1: continue assertion.policy = sorted( assertion.policy, - key=lambda x: int(x[assertion.priority_index]) - if x[assertion.priority_index].isdigit() - else x[assertion.priority_index], + key=lambda x: int(x[priority_index]) + if x[priority_index].isdigit() + else x[priority_index], ) for i, policy in enumerate(assertion.policy): @@ -137,16 +135,12 @@ def sort_policies_by_priority(self): return None def sort_policies_by_subject_hierarchy(self): - if self["e"]["e"].value != "subjectPriority(p_eft) || deny": + if self["e"]["e"].value != SUBJECT_PRIORITY_EFFECT: return sub_index = 0 - domain_index = -1 for ptype, assertion in self["p"].items(): - for index, token in enumerate(assertion.tokens): - if token == "{}_dom".format(ptype): - domain_index = index - break + domain_index = self.get_field_index(ptype, DOMAIN_INDEX) subject_hierarchy_map = self.get_subject_hierarchy_map(self["g"]["g"].policy) @@ -230,23 +224,3 @@ def write_string(sec): s[-1] = s[-1].strip() return "".join(s) - - def get_field_index(self, ptype, field): - """get_field_index gets the index of the field for a ptype in a policy, - return -1 if the field does not exist.""" - assertion = self["p"][ptype] - if field in assertion.field_index_map: - return assertion.field_index_map[field] - - pattern = f"{ptype}_{field}" - index = -1 - for i, token in enumerate(assertion.tokens): - if token == pattern: - index = i - break - - if index == -1: - return index - - assertion.field_index_map[field] = index - return index diff --git a/casbin/model/policy.py b/casbin/model/policy.py index 01a4d8a9..57328914 100644 --- a/casbin/model/policy.py +++ b/casbin/model/policy.py @@ -13,6 +13,8 @@ # limitations under the License. import logging +from casbin.util import util +from casbin.constant.constants import PRIORITY_INDEX DEFAULT_SEP = "," @@ -119,14 +121,17 @@ def add_policy(self, sec, ptype, rule): else: return False - if sec == "p" and assertion.priority_index >= 0: + has_priority = False + if assertion.field_index_map.get(PRIORITY_INDEX) is not None: + has_priority = True + if sec == "p" and has_priority: try: - idx_insert = int(rule[assertion.priority_index]) + idx_insert = int(rule[assertion.field_index_map[PRIORITY_INDEX]]) i = len(assertion.policy) - 1 for i in range(i, 0, -1): try: - idx = int(assertion.policy[i - 1][assertion.priority_index]) + idx = int(assertion.policy[i - 1][assertion.field_index_map[PRIORITY_INDEX]]) except Exception as e: print(e) @@ -303,3 +308,57 @@ def get_values_for_field_in_policy(self, sec, ptype, field_index): values.append(value) return values + + def get_values_for_field_in_policy_all_types(self, sec, field_index): + """gets all values for a field for all rules in a policy of all ptypes, duplicated values are removed.""" + values = [] + if sec not in self.keys(): + return values + + for ptype in self[sec]: + value = self.get_values_for_field_in_policy(sec, ptype, field_index) + values.extend(value) + + values = util.array_remove_duplicates(values) + + return values + + def get_values_for_field_in_policy_all_types_by_name(self, sec, field): + """gets all values for a field for all rules in a policy of all ptypes, duplicated values are removed.""" + values = [] + if sec not in self.keys(): + return values + + for ptype in self[sec]: + index = self.get_field_index(ptype, field) + value = self.get_values_for_field_in_policy(sec, ptype, index) + values.extend(value) + + values = util.array_remove_duplicates(values) + + return values + + def get_field_index(self, ptype, field): + """get_field_index gets the index of the field for a ptype in a policy, + return -1 if the field does not exist.""" + assertion = self["p"][ptype] + if field in assertion.field_index_map: + return assertion.field_index_map[field] + + pattern = f"{ptype}_{field}" + index = -1 + for i, token in enumerate(assertion.tokens): + if token == pattern: + index = i + break + + if index == -1: + return index + + assertion.field_index_map[field] = index + return index + + def set_field_index(self, ptype, field, index): + """sets the index of the field name.""" + assertion = self["p"][ptype] + assertion.field_index_map[field] = index diff --git a/examples/priority_model_explicit_customized.conf b/examples/priority_model_explicit_customized.conf new file mode 100644 index 00000000..5071fa77 --- /dev/null +++ b/examples/priority_model_explicit_customized.conf @@ -0,0 +1,14 @@ +[request_definition] +r = subject, obj, act + +[policy_definition] +p = customized_priority, obj, act, eft, subject + +[role_definition] +g = _, _ + +[policy_effect] +e = priority(p.eft) || deny + +[matchers] +m = g(r.subject, p.subject) && r.obj == p.obj && r.act == p.act \ No newline at end of file diff --git a/examples/priority_policy_explicit_customized.csv b/examples/priority_policy_explicit_customized.csv new file mode 100644 index 00000000..a861e2ba --- /dev/null +++ b/examples/priority_policy_explicit_customized.csv @@ -0,0 +1,12 @@ +p, 10, data1, read, deny, data1_deny_group +p, 10, data1, write, deny, data1_deny_group +p, 10, data2, read, allow, data2_allow_group +p, 10, data2, write, allow, data2_allow_group + + +p, 1, data1, write, allow, alice +p, 1, data1, read, allow, alice +p, 1, data2, read, deny, bob + +g, bob, data2_allow_group +g, alice, data1_deny_group diff --git a/tests/test_rbac_api.py b/tests/test_rbac_api.py index 6171feda..b5987ce3 100644 --- a/tests/test_rbac_api.py +++ b/tests/test_rbac_api.py @@ -14,7 +14,7 @@ from unittest import IsolatedAsyncioTestCase import casbin -from casbin.constant.constants import DOMAIN_INDEX +from casbin.constant.constants import DOMAIN_INDEX, SUBJECT_INDEX, OBJECT_INDEX, PRIORITY_INDEX, ACTION_INDEX from tests.test_enforcer import get_examples, TestCaseBase @@ -468,14 +468,36 @@ def test_domain_match_model(self): self.assertTrue(e.enforce("bob", "domain2", "data2", "read")) self.assertTrue(e.enforce("bob", "domain2", "data2", "write")) - def test_set_field_index(self): + def test_customized_field_index(self): e = self.get_enforcer( - get_examples("rbac_with_domains_model.conf"), - get_examples("rbac_with_domains_policy.csv"), + get_examples("priority_model_explicit_customized.conf"), + get_examples("priority_policy_explicit_customized.csv"), ) - self.assertEqual(e.get_field_index("p", DOMAIN_INDEX), 1) - e.set_field_index("p", DOMAIN_INDEX, 2) - self.assertEqual(e.get_field_index("p", DOMAIN_INDEX), 2) + + self.assertEqual(0, e.get_field_index("p", "customized_priority")) + self.assertEqual(1, e.get_field_index("p", OBJECT_INDEX)) + self.assertEqual(2, e.get_field_index("p", ACTION_INDEX)) + self.assertEqual(3, e.get_field_index("p", "eft")) + self.assertEqual(4, e.get_field_index("p", "subject")) + + self.assertTrue(e.enforce("bob", "data2", "read")) + e.set_field_index("p", PRIORITY_INDEX, 0) + e.load_policy() + self.assertFalse(e.enforce("bob", "data2", "read")) + + self.assertTrue(e.enforce("bob", "data2", "write")) + e.add_policy("1", "data2", "write", "deny", "bob") + self.assertFalse(e.enforce("bob", "data2", "write")) + + self.assertFalse(e.delete_permissions_for_user("bob")) + + e.set_field_index("p", SUBJECT_INDEX, 4) + + self.assertTrue(e.delete_permissions_for_user("bob")) + self.assertTrue(e.enforce("bob", "data2", "write")) + + self.assertTrue(e.delete_role("data2_allow_group")) + self.assertFalse(e.enforce("bob", "data2", "write")) class TestRbacApiSynced(TestRbacApi):