Skip to content

Commit 888f779

Browse files
fix: authority retrieval by mapping per authority and not per directives name. (#201)
Signed-off-by: zufardhiyaulhaq <zufardhiyaulhaq@gmail.com> Co-authored-by: zufardhiyaulhaq <zufardhiyaulhaq@gmail.com>
1 parent b8d727f commit 888f779

File tree

3 files changed

+149
-29
lines changed

3 files changed

+149
-29
lines changed

main_test.go

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -661,6 +661,99 @@ func TestBadResponse(t *testing.T) {
661661
})
662662
}
663663

664+
func TestPerAuthorityDirectives(t *testing.T) {
665+
tests := []struct {
666+
name string
667+
reqHdrs [][2]string
668+
conf string
669+
localResponseIsNil bool
670+
localResponseStatusCode int
671+
}{
672+
{
673+
name: "authority exist on per_authority_directives",
674+
reqHdrs: [][2]string{
675+
{":path", "/rs1"},
676+
{":method", "GET"},
677+
{":authority", "foo.example.com"},
678+
},
679+
conf: `{"directives_map": {"default": ["SecRuleEngine On","SecRule REQUEST_URI \"@streq /admin\" \"id:101,phase:1,t:lowercase,deny\""], "rs1": ["SecRuleEngine On","SecRule REQUEST_URI \"@streq /rs1\" \"id:101,phase:1,t:lowercase,deny\""]}, "default_directives": "default", "per_authority_directives":{"foo.example.com":"rs1"}}`,
680+
localResponseStatusCode: 403,
681+
},
682+
{
683+
name: "authority exist on per_authority_directives but calling allowed path",
684+
reqHdrs: [][2]string{
685+
{":path", "/admin"},
686+
{":method", "GET"},
687+
{":authority", "foo.example.com"},
688+
},
689+
conf: `{"directives_map": {"default": ["SecRuleEngine On","SecRule REQUEST_URI \"@streq /admin\" \"id:101,phase:1,t:lowercase,deny\""], "rs1": ["SecRuleEngine On","SecRule REQUEST_URI \"@streq /rs1\" \"id:101,phase:1,t:lowercase,deny\""]}, "default_directives": "default", "per_authority_directives":{"foo.example.com":"rs1"}}`,
690+
localResponseIsNil: true,
691+
},
692+
{
693+
name: "authority not exist on per_authority_directives",
694+
reqHdrs: [][2]string{
695+
{":path", "/admin"},
696+
{":method", "GET"},
697+
{":authority", "bar.example.com"},
698+
},
699+
conf: `{"directives_map": {"default": ["SecRuleEngine On","SecRule REQUEST_URI \"@streq /admin\" \"id:101,phase:1,t:lowercase,deny\""], "rs1": ["SecRuleEngine On","SecRule REQUEST_URI \"@streq /rs1\" \"id:101,phase:1,t:lowercase,deny\""]}, "default_directives": "default", "per_authority_directives":{"foo.example.com":"rs1"}}`,
700+
localResponseStatusCode: 403,
701+
},
702+
{
703+
name: "authority not exist on per_authority_directives and no default",
704+
reqHdrs: [][2]string{
705+
{":path", "/admin"},
706+
{":method", "GET"},
707+
{":authority", "bar.example.com"},
708+
},
709+
conf: `{"directives_map": {"rs1": ["SecRuleEngine On","SecRule REQUEST_URI \"@streq /rs1\" \"id:101,phase:1,t:lowercase,deny\""]}, "per_authority_directives":{"foo.example.com":"rs1"}}`,
710+
localResponseIsNil: true,
711+
},
712+
{
713+
name: "authority not exist on per_authority_directives but calling allowed value",
714+
reqHdrs: [][2]string{
715+
{":path", "/rs1"},
716+
{":method", "GET"},
717+
{":authority", "bar.example.com"},
718+
},
719+
conf: `{"directives_map": {"default": ["SecRuleEngine On","SecRule REQUEST_URI \"@streq /admin\" \"id:101,phase:1,t:lowercase,deny\""], "rs1": ["SecRuleEngine On","SecRule REQUEST_URI \"@streq /rs1\" \"id:101,phase:1,t:lowercase,deny\""]}, "default_directives": "default", "per_authority_directives":{"foo.example.com":"rs1"}}`,
720+
localResponseIsNil: true,
721+
},
722+
}
723+
724+
vmTest(t, func(t *testing.T, vm types.VMContext) {
725+
for _, tc := range tests {
726+
tt := tc
727+
t.Run(tt.name, func(t *testing.T) {
728+
opt := proxytest.
729+
NewEmulatorOption().
730+
WithVMContext(vm).
731+
WithPluginConfiguration([]byte(tt.conf))
732+
733+
host, reset := proxytest.NewHostEmulator(opt)
734+
defer reset()
735+
736+
require.Equal(t, types.OnPluginStartStatusOK, host.StartPlugin())
737+
738+
id := host.InitializeHttpContext()
739+
740+
host.CallOnRequestHeaders(id, tt.reqHdrs, false)
741+
host.CompleteHttpContext(id)
742+
743+
pluginResp := host.GetSentLocalResponse(id)
744+
745+
if tt.localResponseIsNil {
746+
require.Nil(t, pluginResp)
747+
return
748+
}
749+
750+
require.NotNil(t, pluginResp)
751+
require.EqualValues(t, tt.localResponseStatusCode, pluginResp.StatusCode)
752+
})
753+
}
754+
})
755+
}
756+
664757
func TestEmptyBody(t *testing.T) {
665758
vmTest(t, func(t *testing.T, vm types.VMContext) {
666759
opt := proxytest.

wasmplugin/config_test.go

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -250,18 +250,12 @@ func TestWAFMap(t *testing.T) {
250250
err := wm.put("foo", w)
251251
require.NoError(t, err)
252252

253-
t.Run("set unexisting default key", func(t *testing.T) {
254-
err = wm.setDefaultKey("bar")
255-
require.Error(t, err)
256-
})
257-
258253
t.Run("get unexisting WAF with no default", func(t *testing.T) {
259254
_, _, err := wm.getWAFOrDefault("bar")
260255
require.Error(t, err)
261256
})
262257

263-
err = wm.setDefaultKey("foo")
264-
require.NoError(t, err)
258+
wm.setDefaultWAF(w)
265259

266260
t.Run("get existing WAF", func(t *testing.T) {
267261
expecteWAF, isDefault, err := wm.getWAFOrDefault("foo")

wasmplugin/plugin.go

Lines changed: 55 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ func (*vmContext) NewPluginContext(contextID uint32) types.PluginContext {
3636

3737
type wafMap struct {
3838
kv map[string]coraza.WAF
39-
defaultKey string
39+
defaultWAF coraza.WAF
4040
}
4141

4242
func newWAFMap(capacity int) wafMap {
@@ -54,29 +54,23 @@ func (m *wafMap) put(key string, waf coraza.WAF) error {
5454
return nil
5555
}
5656

57-
func (m *wafMap) setDefaultKey(key string) error {
58-
if len(key) == 0 {
59-
return errors.New("empty default WAF key")
60-
}
61-
62-
if _, ok := m.kv[key]; ok {
63-
m.defaultKey = key
64-
return nil
57+
func (m *wafMap) setDefaultWAF(w coraza.WAF) {
58+
if w == nil {
59+
panic("nil WAF set as default")
6560
}
66-
67-
return fmt.Errorf("unknown default WAF key %q", key)
61+
m.defaultWAF = w
6862
}
6963

7064
func (m *wafMap) getWAFOrDefault(key string) (coraza.WAF, bool, error) {
7165
if w, ok := m.kv[key]; ok {
7266
return w, false, nil
7367
}
7468

75-
if len(m.defaultKey) == 0 {
76-
return nil, false, errors.New("no default WAF key")
69+
if m.defaultWAF == nil {
70+
return nil, false, errors.New("no default WAF")
7771
}
7872

79-
return m.kv[m.defaultKey], true, nil
73+
return m.defaultWAF, true, nil
8074
}
8175

8276
type corazaPlugin struct {
@@ -100,8 +94,33 @@ func (ctx *corazaPlugin) OnPluginStart(pluginConfigurationSize int) types.OnPlug
10094
return types.OnPluginStartStatusFailed
10195
}
10296

97+
// directivesAuthoritesMap is a map of directives name to the list of
98+
// authorities that reference those directives. This is used to
99+
// initialize the WAFs only for the directives that are referenced
100+
directivesAuthoritiesMap := map[string][]string{}
101+
for authority, directivesName := range config.perAuthorityDirectives {
102+
directivesAuthoritiesMap[directivesName] = append(directivesAuthoritiesMap[directivesName], authority)
103+
}
104+
103105
perAuthorityWAFs := newWAFMap(len(config.directivesMap))
104106
for name, directives := range config.directivesMap {
107+
var authorities []string
108+
109+
// if the name of the directives is the default directives, we
110+
// initialize the WAF despite the fact that it is not associated
111+
// to any authority. This is because we need to initialize the
112+
// default WAF for requests that don't belong to any authority.
113+
if name != config.defaultDirectives {
114+
var directivesFound bool
115+
authorities, directivesFound = directivesAuthoritiesMap[name]
116+
if !directivesFound {
117+
// if no directives found as key, no authority references
118+
// these directives and hence we won't initialize them as
119+
// it will be a waste of resources.
120+
continue
121+
}
122+
}
123+
105124
// First we initialize our waf and our seclang parser
106125
conf := coraza.NewWAFConfig().
107126
WithErrorCallback(logError).
@@ -119,18 +138,32 @@ func (ctx *corazaPlugin) OnPluginStart(pluginConfigurationSize int) types.OnPlug
119138
return types.OnPluginStartStatusFailed
120139
}
121140

122-
err = perAuthorityWAFs.put(name, waf)
123-
if err != nil {
124-
proxywasm.LogCriticalf("Failed to register authority WAF: %v", err)
125-
return types.OnPluginStartStatusFailed
141+
if len(authorities) == 0 {
142+
// if no authorities are associated directly with this WAF
143+
// but we still initialize it, it means this is the default
144+
// one.
145+
perAuthorityWAFs.setDefaultWAF(waf)
146+
}
147+
148+
for _, authority := range authorities {
149+
err = perAuthorityWAFs.put(authority, waf)
150+
if err != nil {
151+
proxywasm.LogCriticalf("Failed to register authority WAF: %v", err)
152+
return types.OnPluginStartStatusFailed
153+
}
126154
}
155+
156+
delete(directivesAuthoritiesMap, name)
127157
}
128158

129-
if len(config.defaultDirectives) > 0 {
130-
if err := perAuthorityWAFs.setDefaultKey(config.defaultDirectives); err != nil {
131-
proxywasm.LogCriticalf("Failed to set the default directives: %v", err)
132-
return types.OnPluginStartStatusFailed
159+
if len(directivesAuthoritiesMap) > 0 {
160+
// if there are directives remaining in the directivesAuthoritiesMap, means
161+
// those directives weren't part of the directivesMap and hence not declared.
162+
for unknownDirective := range directivesAuthoritiesMap {
163+
proxywasm.LogCriticalf("Unknown directives %q", unknownDirective)
133164
}
165+
166+
return types.OnPluginStartStatusFailed
134167
}
135168

136169
ctx.perAuthorityWAFs = perAuthorityWAFs

0 commit comments

Comments
 (0)