Skip to content

fix(policy): update policy_map by calling add_policy #377

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 43 additions & 40 deletions casbin/model/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,13 +154,12 @@ def add_policies(self, sec, ptype, rules):
return False

for rule in rules:
self[sec][ptype].policy.append(rule)

if not self.add_policy(sec, ptype, rule):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You should fix all similar issues, including but not limited to: update_policies, remove_policies, remove_filtered_policy

return False
return True

def update_policy(self, sec, ptype, old_rule, new_rule):
"""update a policy rule from the model."""

if sec not in self.keys():
return False
if ptype not in self[sec]:
Expand All @@ -173,20 +172,19 @@ def update_policy(self, sec, ptype, old_rule, new_rule):
else:
return False

if "p_priority" in ast.tokens:
priority_index = ast.tokens.index("p_priority")
if old_rule[priority_index] == new_rule[priority_index]:
ast.policy[rule_index] = new_rule
else:
raise Exception("New rule should have the same priority with old rule.")
else:
ast.policy[rule_index] = new_rule
ast.policy[rule_index] = new_rule

old_key = DEFAULT_SEP.join(old_rule)
new_key = DEFAULT_SEP.join(new_rule)
if old_key in ast.policy_map:
del ast.policy_map[old_key]
ast.policy_map[new_key] = rule_index

return True

def update_policies(self, sec, ptype, old_rules, new_rules):
"""update policy rules from the model."""

"""update policy rules from the model using update_policy for each rule.
If any update fails, roll back all changes."""
if sec not in self.keys():
return False
if ptype not in self[sec]:
Expand All @@ -195,24 +193,15 @@ def update_policies(self, sec, ptype, old_rules, new_rules):
return False

ast = self[sec][ptype]
old_rules_index = []

for old_rule in old_rules:
if old_rule in ast.policy:
old_rules_index.append(ast.policy.index(old_rule))
else:
return False
original_policy = [rule[:] for rule in ast.policy]
original_policy_map = ast.policy_map.copy()

if "p_priority" in ast.tokens:
priority_index = ast.tokens.index("p_priority")
for idx, old_rule, new_rule in zip(old_rules_index, old_rules, new_rules):
if old_rule[priority_index] == new_rule[priority_index]:
ast.policy[idx] = new_rule
else:
raise Exception("New rule should have the same priority with old rule.")
else:
for idx, old_rule, new_rule in zip(old_rules_index, old_rules, new_rules):
ast.policy[idx] = new_rule
for old_rule, new_rule in zip(old_rules, new_rules):
if not self.update_policy(sec, ptype, old_rule, new_rule):
ast.policy = original_policy
ast.policy_map = original_policy_map
return False

return True

Expand All @@ -221,20 +210,21 @@ def remove_policy(self, sec, ptype, rule):
if not self.has_policy(sec, ptype, rule):
return False

self[sec][ptype].policy.remove(rule)
assertion = self[sec][ptype]
assertion.policy.remove(rule)

return rule not in self[sec][ptype].policy
new_map = {}
for idx, r in enumerate(assertion.policy):
new_map[DEFAULT_SEP.join(r)] = idx
assertion.policy_map = new_map

def remove_policies(self, sec, ptype, rules):
"""RemovePolicies removes policy rules from the model."""
return rule not in assertion.policy

def remove_policies(self, sec, ptype, rules):
"""Remove multiple policy rules by sequentially calling remove_policy."""
for rule in rules:
if not self.has_policy(sec, ptype, rule):
return False
self[sec][ptype].policy.remove(rule)
if rule in self[sec][ptype].policy:
if not self.remove_policy(sec, ptype, rule):
return False

return True

def remove_policies_with_effected(self, sec, ptype, rules):
Expand All @@ -249,6 +239,7 @@ def remove_policies_with_effected(self, sec, ptype, rules):
def remove_filtered_policy_returns_effects(self, sec, ptype, field_index, *field_values):
"""
remove_filtered_policy_returns_effects removes policy rules based on field filters from the model.
Returns a list of rules that were removed.
"""
tmp = []
effects = []
Expand All @@ -266,7 +257,13 @@ def remove_filtered_policy_returns_effects(self, sec, ptype, field_index, *field
else:
tmp.append(rule)

self[sec][ptype].policy = tmp
assertion = self[sec][ptype]
assertion.policy = tmp

new_map = {}
for idx, r in enumerate(assertion.policy):
new_map[DEFAULT_SEP.join(r)] = idx
assertion.policy_map = new_map

return effects

Expand All @@ -286,7 +283,13 @@ def remove_filtered_policy(self, sec, ptype, field_index, *field_values):
else:
tmp.append(rule)

self[sec][ptype].policy = tmp
assertion = self[sec][ptype]
assertion.policy = tmp

new_map = {}
for idx, r in enumerate(assertion.policy):
new_map[DEFAULT_SEP.join(r)] = idx
assertion.policy_map = new_map

return res

Expand Down