|
1 | | -from typing import override |
| 1 | +from typing import cast, override |
2 | 2 |
|
3 | 3 | from ai.backend.manager.actions.monitors.monitor import ActionMonitor |
4 | 4 | from ai.backend.manager.actions.processor import ActionProcessor |
5 | 5 | from ai.backend.manager.actions.types import AbstractProcessorPackage, ActionSpec |
| 6 | +from ai.backend.manager.actions.validator.base import ActionValidator |
| 7 | +from ai.backend.manager.actions.validators.rbac.scope import ScopeActionRBACValidator |
| 8 | +from ai.backend.manager.actions.validators.rbac.single_entity import SingleEntityActionRBACValidator |
| 9 | +from ai.backend.manager.repositories.permission_controller.repository import ( |
| 10 | + PermissionControllerRepository, |
| 11 | +) |
6 | 12 | from ai.backend.manager.services.session.actions.check_and_transit_status import ( |
7 | 13 | CheckAndTransitStatusAction, |
8 | 14 | CheckAndTransitStatusActionResult, |
@@ -165,43 +171,95 @@ class SessionProcessors(AbstractProcessorPackage): |
165 | 171 | CheckAndTransitStatusAction, CheckAndTransitStatusActionResult |
166 | 172 | ] |
167 | 173 |
|
168 | | - def __init__(self, service: SessionService, action_monitors: list[ActionMonitor]) -> None: |
| 174 | + def __init__( |
| 175 | + self, |
| 176 | + service: SessionService, |
| 177 | + action_monitors: list[ActionMonitor], |
| 178 | + permission_repository: PermissionControllerRepository, |
| 179 | + ) -> None: |
| 180 | + # Create RBAC validators |
| 181 | + scope_validator = ScopeActionRBACValidator(permission_repository) |
| 182 | + single_entity_validator = SingleEntityActionRBACValidator(permission_repository) |
| 183 | + |
| 184 | + # Actions without RBAC validation (internal/legacy) |
169 | 185 | self.commit_session = ActionProcessor(service.commit_session, action_monitors) |
170 | 186 | self.complete = ActionProcessor(service.complete, action_monitors) |
171 | 187 | self.convert_session_to_image = ActionProcessor( |
172 | 188 | service.convert_session_to_image, action_monitors |
173 | 189 | ) |
174 | | - self.create_cluster = ActionProcessor(service.create_cluster, action_monitors) |
175 | | - self.create_from_params = ActionProcessor(service.create_from_params, action_monitors) |
176 | | - self.create_from_template = ActionProcessor(service.create_from_template, action_monitors) |
177 | | - self.destroy_session = ActionProcessor(service.destroy_session, action_monitors) |
178 | 190 | self.download_file = ActionProcessor(service.download_file, action_monitors) |
179 | 191 | self.download_files = ActionProcessor(service.download_files, action_monitors) |
180 | | - self.execute_session = ActionProcessor(service.execute_session, action_monitors) |
181 | 192 | self.get_abusing_report = ActionProcessor(service.get_abusing_report, action_monitors) |
182 | 193 | self.get_commit_status = ActionProcessor(service.get_commit_status, action_monitors) |
183 | 194 | self.get_container_logs = ActionProcessor(service.get_container_logs, action_monitors) |
184 | 195 | self.get_dependency_graph = ActionProcessor(service.get_dependency_graph, action_monitors) |
185 | 196 | self.get_direct_access_info = ActionProcessor( |
186 | 197 | service.get_direct_access_info, action_monitors |
187 | 198 | ) |
188 | | - self.get_session_info = ActionProcessor(service.get_session_info, action_monitors) |
189 | 199 | self.get_status_history = ActionProcessor(service.get_status_history, action_monitors) |
190 | 200 | self.interrupt = ActionProcessor(service.interrupt, action_monitors) |
191 | 201 | self.list_files = ActionProcessor(service.list_files, action_monitors) |
192 | | - self.match_sessions = ActionProcessor(service.match_sessions, action_monitors) |
193 | 202 | self.rename_session = ActionProcessor(service.rename_session, action_monitors) |
194 | 203 | self.restart_session = ActionProcessor(service.restart_session, action_monitors) |
195 | | - self.search_kernels = ActionProcessor(service.search_kernels, action_monitors) |
196 | | - self.search_sessions = ActionProcessor(service.search, action_monitors) |
197 | 204 | self.shutdown_service = ActionProcessor(service.shutdown_service, action_monitors) |
198 | 205 | self.start_service = ActionProcessor(service.start_service, action_monitors) |
199 | 206 | self.upload_files = ActionProcessor(service.upload_files, action_monitors) |
200 | | - self.modify_session = ActionProcessor(service.modify_session, action_monitors) |
201 | 207 | self.check_and_transit_status = ActionProcessor( |
202 | 208 | service.check_and_transit_status, action_monitors |
203 | 209 | ) |
204 | 210 |
|
| 211 | + # Scope actions with RBAC validation |
| 212 | + self.create_cluster = ActionProcessor( |
| 213 | + service.create_cluster, |
| 214 | + action_monitors, |
| 215 | + validators=[cast(ActionValidator, scope_validator)], |
| 216 | + ) |
| 217 | + self.create_from_params = ActionProcessor( |
| 218 | + service.create_from_params, |
| 219 | + action_monitors, |
| 220 | + validators=[cast(ActionValidator, scope_validator)], |
| 221 | + ) |
| 222 | + self.create_from_template = ActionProcessor( |
| 223 | + service.create_from_template, |
| 224 | + action_monitors, |
| 225 | + validators=[cast(ActionValidator, scope_validator)], |
| 226 | + ) |
| 227 | + self.match_sessions = ActionProcessor( |
| 228 | + service.match_sessions, |
| 229 | + action_monitors, |
| 230 | + validators=[cast(ActionValidator, scope_validator)], |
| 231 | + ) |
| 232 | + self.search_kernels = ActionProcessor( |
| 233 | + service.search_kernels, |
| 234 | + action_monitors, |
| 235 | + validators=[cast(ActionValidator, scope_validator)], |
| 236 | + ) |
| 237 | + self.search_sessions = ActionProcessor( |
| 238 | + service.search, action_monitors, validators=[cast(ActionValidator, scope_validator)] |
| 239 | + ) |
| 240 | + |
| 241 | + # Single entity actions with RBAC validation |
| 242 | + self.destroy_session = ActionProcessor( |
| 243 | + service.destroy_session, |
| 244 | + action_monitors, |
| 245 | + validators=[cast(ActionValidator, single_entity_validator)], |
| 246 | + ) |
| 247 | + self.execute_session = ActionProcessor( |
| 248 | + service.execute_session, |
| 249 | + action_monitors, |
| 250 | + validators=[cast(ActionValidator, single_entity_validator)], |
| 251 | + ) |
| 252 | + self.get_session_info = ActionProcessor( |
| 253 | + service.get_session_info, |
| 254 | + action_monitors, |
| 255 | + validators=[cast(ActionValidator, single_entity_validator)], |
| 256 | + ) |
| 257 | + self.modify_session = ActionProcessor( |
| 258 | + service.modify_session, |
| 259 | + action_monitors, |
| 260 | + validators=[cast(ActionValidator, single_entity_validator)], |
| 261 | + ) |
| 262 | + |
205 | 263 | @override |
206 | 264 | def supported_actions(self) -> list[ActionSpec]: |
207 | 265 | return [ |
|
0 commit comments