Skip to content

Commit f4b0816

Browse files
address comments
Signed-off-by: AmoebaProtozoa <8039876+AmoebaProtozoa@users.noreply.github.com>
1 parent 13a5fbe commit f4b0816

3 files changed

Lines changed: 87 additions & 5 deletions

File tree

pkg/util/sem/v2/restricted_statement.go

Lines changed: 65 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,10 @@ func isRestrictedRole(userName, hostname string) bool {
4040

4141
// IsRestrictedStatement returns a non-nil error when strict SEM forbids stmt.
4242
// It runs the check unconditionally; callers gate on IsStrictEnabled.
43+
//
44+
// The switch is a default-deny allow-list: any StmtNode not matched below is
45+
// rejected by the trailing notSupported call. When adding a new statement
46+
// type, decide explicitly whether it should be allowed or rejected.
4347
func IsRestrictedStatement(stmt ast.Node) error {
4448
switch x := stmt.(type) {
4549
case *ast.DeallocateStmt,
@@ -62,7 +66,28 @@ func IsRestrictedStatement(stmt ast.Node) error {
6266
*ast.CreateBindingStmt,
6367
*ast.DropBindingStmt,
6468
*ast.SetBindingStmt,
65-
*ast.CompactTableStmt:
69+
*ast.CompactTableStmt,
70+
// Carried over from the previous default-allow fall-through. They are
71+
// preserved as allow here to keep behavior identical to before the
72+
// switch was tightened to default-deny; revisit each in a follow-up.
73+
*ast.AddQueryWatchStmt,
74+
*ast.AlterRangeStmt,
75+
*ast.CalibrateResourceStmt,
76+
*ast.CallStmt,
77+
*ast.CancelDistributionJobStmt,
78+
*ast.CreateStatisticsStmt,
79+
*ast.DistributeTableStmt,
80+
*ast.DropProcedureStmt,
81+
*ast.DropQueryWatchStmt,
82+
*ast.DropStatisticsStmt,
83+
*ast.GrantProxyStmt,
84+
*ast.HelpStmt,
85+
*ast.ImportIntoActionStmt,
86+
*ast.ImportIntoStmt,
87+
*ast.RecommendIndexStmt,
88+
*ast.RefreshStatsStmt,
89+
*ast.RestartStmt,
90+
*ast.TrafficStmt:
6691
return nil
6792
case *ast.LoadDataStmt:
6893
return verifyLoadData(x)
@@ -87,7 +112,7 @@ func IsRestrictedStatement(stmt ast.Node) error {
87112
return notSupported("SPLIT REGION")
88113
}
89114

90-
return nil
115+
return notSupported(fmt.Sprintf("Unsupported statement: %T", stmt))
91116
}
92117

93118
func verifyDDL(stmt ast.DDLNode) error {
@@ -158,7 +183,6 @@ func verifySimple(stmt ast.Node) error {
158183
*ast.KillStmt,
159184
*ast.BinlogStmt,
160185
*ast.DropStatsStmt,
161-
*ast.SetDefaultRoleStmt,
162186
*ast.AdminStmt,
163187
*ast.GrantStmt,
164188
*ast.RevokeStmt,
@@ -194,12 +218,50 @@ func verifySimple(stmt ast.Node) error {
194218
}
195219
return nil
196220
case *ast.SetRoleStmt:
221+
// SET ROLE DEFAULT|ALL|ALL EXCEPT can implicitly activate a granted
222+
// restricted role (cloud_admin), bypassing the RoleList-only check.
223+
// Reject those wildcard forms; allow NONE (deactivates everything)
224+
// and the regular form after the RoleList contains no restricted role.
225+
switch s.SetRoleOpt {
226+
case ast.SetRoleNone:
227+
return nil
228+
case ast.SetRoleDefault:
229+
return notSupported("SET ROLE DEFAULT")
230+
case ast.SetRoleAll:
231+
return notSupported("SET ROLE ALL")
232+
case ast.SetRoleAllExcept:
233+
return notSupported("SET ROLE ALL EXCEPT")
234+
}
197235
for _, role := range s.RoleList {
198236
if isRestrictedRole(role.Username, role.Hostname) {
199237
return notSupported(fmt.Sprintf("SET ROLE %s", role))
200238
}
201239
}
202240
return nil
241+
case *ast.SetDefaultRoleStmt:
242+
// SET DEFAULT ROLE is the matching write-side: it persists the role
243+
// list a user activates via SET ROLE DEFAULT. Block any attempt to
244+
// (a) target a restricted account (root, cloud_admin) so its default
245+
// role set cannot be tampered with, (b) install a restricted role as
246+
// someone else's default, and (c) use ALL, which would silently
247+
// include cloud_admin when it has been granted to the target.
248+
for _, user := range s.UserList {
249+
if isRestrictedUser(user.Username, user.Hostname) {
250+
return notSupported(fmt.Sprintf("SET DEFAULT ROLE TO %s", user))
251+
}
252+
}
253+
switch s.SetRoleOpt {
254+
case ast.SetRoleNone:
255+
return nil
256+
case ast.SetRoleAll:
257+
return notSupported("SET DEFAULT ROLE ALL")
258+
}
259+
for _, role := range s.RoleList {
260+
if isRestrictedRole(role.Username, role.Hostname) {
261+
return notSupported(fmt.Sprintf("SET DEFAULT ROLE %s", role))
262+
}
263+
}
264+
return nil
203265
case *ast.AlterInstanceStmt:
204266
return notSupported("ALTER INSTANCE")
205267
case *ast.ShutdownStmt:

pkg/util/sem/v2/restricted_statement_test.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,25 @@ func TestStatement_RestrictedUserOperations(t *testing.T) {
147147
mustReject(t, `SET ROLE 'cloud_admin'@'%'`, "SET ROLE")
148148
mustPass(t, `SET ROLE 'role1'@'%'`)
149149

150+
// SET ROLE wildcard forms can implicitly activate cloud_admin if granted,
151+
// so they are blocked outright. NONE is safe (deactivates everything).
152+
mustReject(t, `SET ROLE DEFAULT`, "SET ROLE DEFAULT")
153+
mustReject(t, `SET ROLE ALL`, "SET ROLE ALL")
154+
mustReject(t, `SET ROLE ALL EXCEPT 'role1'@'%'`, "SET ROLE ALL EXCEPT")
155+
mustPass(t, `SET ROLE NONE`)
156+
157+
// SET DEFAULT ROLE writes the role list activated by SET ROLE DEFAULT, so
158+
// it must enforce the same protection on both ends: no restricted role can
159+
// be installed as a default, the target user list cannot include a
160+
// restricted account, and ALL is rejected because it would silently include
161+
// cloud_admin when granted.
162+
mustReject(t, `SET DEFAULT ROLE 'cloud_admin'@'%' TO 'alice'@'%'`, "SET DEFAULT ROLE")
163+
mustReject(t, `SET DEFAULT ROLE 'role1'@'%' TO 'root'@'%'`, "SET DEFAULT ROLE TO")
164+
mustReject(t, `SET DEFAULT ROLE 'role1'@'%' TO 'cloud_admin'@'%'`, "SET DEFAULT ROLE TO")
165+
mustReject(t, `SET DEFAULT ROLE ALL TO 'alice'@'%'`, "SET DEFAULT ROLE ALL")
166+
mustPass(t, `SET DEFAULT ROLE 'role1'@'%' TO 'alice'@'%'`)
167+
mustPass(t, `SET DEFAULT ROLE NONE TO 'alice'@'%'`)
168+
150169
mustPass(t, `CREATE USER 'x'@'%' IDENTIFIED BY 'pw'`)
151170
mustPass(t, `ALTER USER 'x'@'%' IDENTIFIED BY 'pw2'`)
152171
mustPass(t, `SET PASSWORD FOR 'x'@'%' = 'pw3'`)

tests/realtikvtest/pipelineddmltest/pipelineddml_test.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -345,11 +345,12 @@ func TestPipelinedDMLNegative(t *testing.T) {
345345
tk.MustQuery("show warnings").CheckContain("Pipelined DML can not be used when tidb_constraint_check_in_place=ON. Fallback to standard mode")
346346
tk.MustExec("set @@tidb_constraint_check_in_place = 0")
347347

348-
// strict SEM (starter/essential deployments)
348+
// strict SEM (starter/essential deployments). Use t.Cleanup so a mid-test
349+
// assertion failure cannot leak the global strict flag into later tests.
349350
semv2.EnableStrict()
351+
t.Cleanup(semv2.DisableStrict)
350352
tk.MustExec("insert into t values(12, 12)")
351353
tk.MustQuery("show warnings").CheckContain("Pipelined DML is not supported in this deployment. Fallback to standard mode")
352-
semv2.DisableStrict()
353354
}
354355

355356
func compareTables(t *testing.T, tk *testkit.TestKit, t1, t2 string) {

0 commit comments

Comments
 (0)