From 46ba668bc490e74eb29b53162e973ed0719653c6 Mon Sep 17 00:00:00 2001 From: Victor Bayim <2506976721@qq.com> Date: Thu, 20 Mar 2025 23:23:42 +0800 Subject: [PATCH 1/3] fix(policy): synchronize policy_map updates in add, update, and remove operations --- casbin/model/policy.py | 89 +++++++++++++++++++++++++----------------- 1 file changed, 53 insertions(+), 36 deletions(-) diff --git a/casbin/model/policy.py b/casbin/model/policy.py index 5daeea3..a693ba7 100644 --- a/casbin/model/policy.py +++ b/casbin/model/policy.py @@ -39,10 +39,8 @@ def items(self): def build_role_links(self, rm_map): """initializes the roles in RBAC.""" - if "g" not in self.keys(): return - for ptype, ast in self["g"].items(): rm = rm_map.get(ptype) if rm: @@ -68,28 +66,23 @@ def build_conditional_role_links(self, cond_rm_map): def print_policy(self): """Log using info""" - self.logger.info("Policy:") for sec in ["p", "g"]: if sec not in self.keys(): continue - for key, ast in self[sec].items(): self.logger.info("{} : {} : {}".format(key, ast.value, ast.policy)) def clear_policy(self): """clears all current policy.""" - for sec in ["p", "g"]: if sec not in self.keys(): continue - for key in self[sec].keys(): self[sec][key].policy = [] def get_policy(self, sec, ptype): """gets all rules in a policy.""" - return self[sec][ptype].policy def get_filtered_policy(self, sec, ptype, field_index, *field_values): @@ -109,7 +102,6 @@ def has_policy(self, sec, ptype, rule): return False if ptype not in self[sec]: return False - return rule in self[sec][ptype].policy def add_policy(self, sec, ptype, rule): @@ -123,23 +115,19 @@ def add_policy(self, sec, ptype, rule): if sec == "p" and assertion.priority_index >= 0: try: idx_insert = int(rule[assertion.priority_index]) - i = len(assertion.policy) - 1 for i in range(i, 0, -1): try: idx = int(assertion.policy[i - 1][assertion.priority_index]) except Exception as e: print(e) - if idx > idx_insert: tmp = assertion.policy[i] assertion.policy[i] = assertion.policy[i - 1] assertion.policy[i - 1] = tmp else: break - assertion.policy_map[DEFAULT_SEP.join(rule)] = i - except Exception as e: print(e) @@ -148,19 +136,16 @@ def add_policy(self, sec, ptype, rule): def add_policies(self, sec, ptype, rules): """adds policy rules to the model.""" - for rule in rules: if self.has_policy(sec, ptype, rule): return False - for rule in rules: - self[sec][ptype].policy.append(rule) - + if not self.add_policy(sec, ptype, rule): + return False return True def update_policy(self, sec, ptype, old_rule, new_rule): """update a policy rule from the model.""" - if sec not in self.keys(): return False if ptype not in self[sec]: @@ -175,18 +160,21 @@ def update_policy(self, sec, ptype, old_rule, new_rule): if "p_priority" in ast.tokens: priority_index = ast.tokens.index("p_priority") - if old_rule[priority_index] == new_rule[priority_index]: - ast.policy[rule_index] = new_rule - else: + if old_rule[priority_index] != new_rule[priority_index]: raise Exception("New rule should have the same priority with old rule.") - else: - ast.policy[rule_index] = new_rule + # 替换列表中的规则 + ast.policy[rule_index] = new_rule + # 更新映射:删除旧键,添加新键 + old_key = DEFAULT_SEP.join(old_rule) + new_key = DEFAULT_SEP.join(new_rule) + if old_key in ast.policy_map: + del ast.policy_map[old_key] + ast.policy_map[new_key] = rule_index return True def update_policies(self, sec, ptype, old_rules, new_rules): """update policy rules from the model.""" - if sec not in self.keys(): return False if ptype not in self[sec]: @@ -206,13 +194,22 @@ def update_policies(self, sec, ptype, old_rules, new_rules): if "p_priority" in ast.tokens: priority_index = ast.tokens.index("p_priority") for idx, old_rule, new_rule in zip(old_rules_index, old_rules, new_rules): - if old_rule[priority_index] == new_rule[priority_index]: - ast.policy[idx] = new_rule - else: + if old_rule[priority_index] != new_rule[priority_index]: raise Exception("New rule should have the same priority with old rule.") + ast.policy[idx] = new_rule + old_key = DEFAULT_SEP.join(old_rule) + new_key = DEFAULT_SEP.join(new_rule) + if old_key in ast.policy_map: + del ast.policy_map[old_key] + ast.policy_map[new_key] = idx else: for idx, old_rule, new_rule in zip(old_rules_index, old_rules, new_rules): ast.policy[idx] = new_rule + old_key = DEFAULT_SEP.join(old_rule) + new_key = DEFAULT_SEP.join(new_rule) + if old_key in ast.policy_map: + del ast.policy_map[old_key] + ast.policy_map[new_key] = idx return True @@ -221,19 +218,30 @@ def remove_policy(self, sec, ptype, rule): if not self.has_policy(sec, ptype, rule): return False - self[sec][ptype].policy.remove(rule) + assertion = self[sec][ptype] + assertion.policy.remove(rule) + # 重新构建映射 + new_map = {} + for idx, r in enumerate(assertion.policy): + new_map[DEFAULT_SEP.join(r)] = idx + assertion.policy_map = new_map - return rule not in self[sec][ptype].policy + return rule not in assertion.policy def remove_policies(self, sec, ptype, rules): """RemovePolicies removes policy rules from the model.""" - + assertion = self[sec][ptype] for rule in rules: if not self.has_policy(sec, ptype, rule): return False - self[sec][ptype].policy.remove(rule) - if rule in self[sec][ptype].policy: + assertion.policy.remove(rule) + if rule in assertion.policy: return False + # 重新构建映射 + new_map = {} + for idx, r in enumerate(assertion.policy): + new_map[DEFAULT_SEP.join(r)] = idx + assertion.policy_map = new_map return True @@ -243,7 +251,6 @@ def remove_policies_with_effected(self, sec, ptype, rules): if self.has_policy(sec, ptype, rule): effected.append(rule) self.remove_policy(sec, ptype, rule) - return effected def remove_filtered_policy_returns_effects(self, sec, ptype, field_index, *field_values): @@ -266,7 +273,13 @@ def remove_filtered_policy_returns_effects(self, sec, ptype, field_index, *field else: tmp.append(rule) - self[sec][ptype].policy = tmp + assertion = self[sec][ptype] + assertion.policy = tmp + # 重新构建映射 + new_map = {} + for idx, r in enumerate(assertion.policy): + new_map[DEFAULT_SEP.join(r)] = idx + assertion.policy_map = new_map return effects @@ -286,7 +299,13 @@ def remove_filtered_policy(self, sec, ptype, field_index, *field_values): else: tmp.append(rule) - self[sec][ptype].policy = tmp + assertion = self[sec][ptype] + assertion.policy = tmp + # 重新构建映射 + new_map = {} + for idx, r in enumerate(assertion.policy): + new_map[DEFAULT_SEP.join(r)] = idx + assertion.policy_map = new_map return res @@ -297,10 +316,8 @@ def get_values_for_field_in_policy(self, sec, ptype, field_index): return values if ptype not in self[sec]: return values - for rule in self[sec][ptype].policy: value = rule[field_index] if value not in values: values.append(value) - return values From 0febd1dd9bd9ae56609a33a8b870906026c24002 Mon Sep 17 00:00:00 2001 From: Victor Bayim <2506976721@qq.com> Date: Sat, 22 Mar 2025 01:16:57 +0800 Subject: [PATCH 2/3] fix(policy): synchronize policy_map updates in update and remove operations --- casbin/model/policy.py | 89 +++++++++++++++++++++--------------------- 1 file changed, 44 insertions(+), 45 deletions(-) diff --git a/casbin/model/policy.py b/casbin/model/policy.py index a693ba7..b9f3933 100644 --- a/casbin/model/policy.py +++ b/casbin/model/policy.py @@ -39,8 +39,10 @@ def items(self): def build_role_links(self, rm_map): """initializes the roles in RBAC.""" + if "g" not in self.keys(): return + for ptype, ast in self["g"].items(): rm = rm_map.get(ptype) if rm: @@ -66,23 +68,28 @@ def build_conditional_role_links(self, cond_rm_map): def print_policy(self): """Log using info""" + self.logger.info("Policy:") for sec in ["p", "g"]: if sec not in self.keys(): continue + for key, ast in self[sec].items(): self.logger.info("{} : {} : {}".format(key, ast.value, ast.policy)) def clear_policy(self): """clears all current policy.""" + for sec in ["p", "g"]: if sec not in self.keys(): continue + for key in self[sec].keys(): self[sec][key].policy = [] def get_policy(self, sec, ptype): """gets all rules in a policy.""" + return self[sec][ptype].policy def get_filtered_policy(self, sec, ptype, field_index, *field_values): @@ -102,6 +109,7 @@ def has_policy(self, sec, ptype, rule): return False if ptype not in self[sec]: return False + return rule in self[sec][ptype].policy def add_policy(self, sec, ptype, rule): @@ -115,19 +123,23 @@ def add_policy(self, sec, ptype, rule): if sec == "p" and assertion.priority_index >= 0: try: idx_insert = int(rule[assertion.priority_index]) + i = len(assertion.policy) - 1 for i in range(i, 0, -1): try: idx = int(assertion.policy[i - 1][assertion.priority_index]) except Exception as e: print(e) + if idx > idx_insert: tmp = assertion.policy[i] assertion.policy[i] = assertion.policy[i - 1] assertion.policy[i - 1] = tmp else: break + assertion.policy_map[DEFAULT_SEP.join(rule)] = i + except Exception as e: print(e) @@ -136,9 +148,11 @@ def add_policy(self, sec, ptype, rule): def add_policies(self, sec, ptype, rules): """adds policy rules to the model.""" + for rule in rules: if self.has_policy(sec, ptype, rule): return False + for rule in rules: if not self.add_policy(sec, ptype, rule): return False @@ -158,13 +172,13 @@ def update_policy(self, sec, ptype, old_rule, new_rule): else: return False - if "p_priority" in ast.tokens: + if ast.tokens and "p_priority" in ast.tokens: priority_index = ast.tokens.index("p_priority") if old_rule[priority_index] != new_rule[priority_index]: raise Exception("New rule should have the same priority with old rule.") - # 替换列表中的规则 + ast.policy[rule_index] = new_rule - # 更新映射:删除旧键,添加新键 + old_key = DEFAULT_SEP.join(old_rule) new_key = DEFAULT_SEP.join(new_rule) if old_key in ast.policy_map: @@ -173,8 +187,11 @@ def update_policy(self, sec, ptype, old_rule, new_rule): return True + + def update_policies(self, sec, ptype, old_rules, new_rules): - """update policy rules from the model.""" + """update policy rules from the model using update_policy for each rule. + If any update fails, roll back all changes.""" if sec not in self.keys(): return False if ptype not in self[sec]: @@ -183,36 +200,20 @@ def update_policies(self, sec, ptype, old_rules, new_rules): return False ast = self[sec][ptype] - old_rules_index = [] - for old_rule in old_rules: - if old_rule in ast.policy: - old_rules_index.append(ast.policy.index(old_rule)) - else: - return False + original_policy = [rule[:] for rule in ast.policy] + original_policy_map = ast.policy_map.copy() - if "p_priority" in ast.tokens: - priority_index = ast.tokens.index("p_priority") - for idx, old_rule, new_rule in zip(old_rules_index, old_rules, new_rules): - if old_rule[priority_index] != new_rule[priority_index]: - raise Exception("New rule should have the same priority with old rule.") - ast.policy[idx] = new_rule - old_key = DEFAULT_SEP.join(old_rule) - new_key = DEFAULT_SEP.join(new_rule) - if old_key in ast.policy_map: - del ast.policy_map[old_key] - ast.policy_map[new_key] = idx - else: - for idx, old_rule, new_rule in zip(old_rules_index, old_rules, new_rules): - ast.policy[idx] = new_rule - old_key = DEFAULT_SEP.join(old_rule) - new_key = DEFAULT_SEP.join(new_rule) - if old_key in ast.policy_map: - del ast.policy_map[old_key] - ast.policy_map[new_key] = idx + for old_rule, new_rule in zip(old_rules, new_rules): + if not self.update_policy(sec, ptype, old_rule, new_rule): + ast.policy = original_policy + ast.policy_map = original_policy_map + return False return True + + def remove_policy(self, sec, ptype, rule): """removes a policy rule from the model.""" if not self.has_policy(sec, ptype, rule): @@ -220,7 +221,7 @@ def remove_policy(self, sec, ptype, rule): assertion = self[sec][ptype] assertion.policy.remove(rule) - # 重新构建映射 + new_map = {} for idx, r in enumerate(assertion.policy): new_map[DEFAULT_SEP.join(r)] = idx @@ -228,34 +229,28 @@ def remove_policy(self, sec, ptype, rule): return rule not in assertion.policy + def remove_policies(self, sec, ptype, rules): - """RemovePolicies removes policy rules from the model.""" - assertion = self[sec][ptype] + """Remove multiple policy rules by sequentially calling remove_policy.""" for rule in rules: - if not self.has_policy(sec, ptype, rule): + if not self.remove_policy(sec, ptype, rule): return False - assertion.policy.remove(rule) - if rule in assertion.policy: - return False - # 重新构建映射 - new_map = {} - for idx, r in enumerate(assertion.policy): - new_map[DEFAULT_SEP.join(r)] = idx - assertion.policy_map = new_map - return True + def remove_policies_with_effected(self, sec, ptype, rules): effected = [] for rule in rules: if self.has_policy(sec, ptype, rule): effected.append(rule) self.remove_policy(sec, ptype, rule) + return effected def remove_filtered_policy_returns_effects(self, sec, ptype, field_index, *field_values): """ remove_filtered_policy_returns_effects removes policy rules based on field filters from the model. + Returns a list of rules that were removed. """ tmp = [] effects = [] @@ -275,7 +270,7 @@ def remove_filtered_policy_returns_effects(self, sec, ptype, field_index, *field assertion = self[sec][ptype] assertion.policy = tmp - # 重新构建映射 + new_map = {} for idx, r in enumerate(assertion.policy): new_map[DEFAULT_SEP.join(r)] = idx @@ -283,6 +278,7 @@ def remove_filtered_policy_returns_effects(self, sec, ptype, field_index, *field return effects + def remove_filtered_policy(self, sec, ptype, field_index, *field_values): """removes policy rules based on field filters from the model.""" tmp = [] @@ -301,7 +297,7 @@ def remove_filtered_policy(self, sec, ptype, field_index, *field_values): assertion = self[sec][ptype] assertion.policy = tmp - # 重新构建映射 + new_map = {} for idx, r in enumerate(assertion.policy): new_map[DEFAULT_SEP.join(r)] = idx @@ -309,6 +305,7 @@ def remove_filtered_policy(self, sec, ptype, field_index, *field_values): return res + def get_values_for_field_in_policy(self, sec, ptype, field_index): """gets all values for a field for all rules in a policy, duplicated values are removed.""" values = [] @@ -316,8 +313,10 @@ def get_values_for_field_in_policy(self, sec, ptype, field_index): return values if ptype not in self[sec]: return values + for rule in self[sec][ptype].policy: value = rule[field_index] if value not in values: values.append(value) - return values + + return values \ No newline at end of file From 11079f33e1d74ab86ba3f4e8dfc537969322a2e4 Mon Sep 17 00:00:00 2001 From: Victor Bayim <2506976721@qq.com> Date: Sat, 22 Mar 2025 16:30:08 +0800 Subject: [PATCH 3/3] fix(policy): synchronize policy_map update --- casbin/model/policy.py | 19 +++---------------- 1 file changed, 3 insertions(+), 16 deletions(-) diff --git a/casbin/model/policy.py b/casbin/model/policy.py index b9f3933..d3da6b9 100644 --- a/casbin/model/policy.py +++ b/casbin/model/policy.py @@ -172,11 +172,6 @@ def update_policy(self, sec, ptype, old_rule, new_rule): else: return False - if ast.tokens and "p_priority" in ast.tokens: - priority_index = ast.tokens.index("p_priority") - if old_rule[priority_index] != new_rule[priority_index]: - raise Exception("New rule should have the same priority with old rule.") - ast.policy[rule_index] = new_rule old_key = DEFAULT_SEP.join(old_rule) @@ -187,8 +182,6 @@ def update_policy(self, sec, ptype, old_rule, new_rule): return True - - def update_policies(self, sec, ptype, old_rules, new_rules): """update policy rules from the model using update_policy for each rule. If any update fails, roll back all changes.""" @@ -212,8 +205,6 @@ def update_policies(self, sec, ptype, old_rules, new_rules): return True - - def remove_policy(self, sec, ptype, rule): """removes a policy rule from the model.""" if not self.has_policy(sec, ptype, rule): @@ -229,7 +220,6 @@ def remove_policy(self, sec, ptype, rule): return rule not in assertion.policy - def remove_policies(self, sec, ptype, rules): """Remove multiple policy rules by sequentially calling remove_policy.""" for rule in rules: @@ -237,7 +227,6 @@ def remove_policies(self, sec, ptype, rules): return False return True - def remove_policies_with_effected(self, sec, ptype, rules): effected = [] for rule in rules: @@ -270,7 +259,7 @@ def remove_filtered_policy_returns_effects(self, sec, ptype, field_index, *field assertion = self[sec][ptype] assertion.policy = tmp - + new_map = {} for idx, r in enumerate(assertion.policy): new_map[DEFAULT_SEP.join(r)] = idx @@ -278,7 +267,6 @@ def remove_filtered_policy_returns_effects(self, sec, ptype, field_index, *field return effects - def remove_filtered_policy(self, sec, ptype, field_index, *field_values): """removes policy rules based on field filters from the model.""" tmp = [] @@ -297,7 +285,7 @@ def remove_filtered_policy(self, sec, ptype, field_index, *field_values): assertion = self[sec][ptype] assertion.policy = tmp - + new_map = {} for idx, r in enumerate(assertion.policy): new_map[DEFAULT_SEP.join(r)] = idx @@ -305,7 +293,6 @@ def remove_filtered_policy(self, sec, ptype, field_index, *field_values): return res - def get_values_for_field_in_policy(self, sec, ptype, field_index): """gets all values for a field for all rules in a policy, duplicated values are removed.""" values = [] @@ -319,4 +306,4 @@ def get_values_for_field_in_policy(self, sec, ptype, field_index): if value not in values: values.append(value) - return values \ No newline at end of file + return values