Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 41 additions & 3 deletions path_login.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"fmt"
"net/http"
"strings"
"time"

"github.com/hashicorp/errwrap"
"github.com/hashicorp/go-hclog"
Expand All @@ -35,6 +36,11 @@ func (b *backend) pathLogin() *framework.Path {
Type: framework.TypeString,
Description: `SPNEGO Authorization header. Required.`,
},
"ttl": {
Type: framework.TypeString,
Description: "Optional custom TTL for the session token (e.g., '1m')",
Default: "",
},
},
Operations: map[logical.Operation]framework.OperationHandler{
logical.ReadOperation: &framework.PathOperation{
Expand Down Expand Up @@ -102,6 +108,7 @@ func (b *backend) pathLoginUpdate(ctx context.Context, req *logical.Request, d *
}
}

// Obtain the SPNEGO token from headers or JSON field
authorizationString := ""
authorizationHeaders := req.Headers["Authorization"]
if len(authorizationHeaders) > 0 {
Expand Down Expand Up @@ -163,7 +170,10 @@ func (b *backend) pathLoginUpdate(ctx context.Context, req *logical.Request, d *
// therefore policies, from a separate directory.
if ldapCfg.ConfigEntry.UPNDomain != "" && identity.Domain() != ldapCfg.ConfigEntry.UPNDomain {
w.WriteHeader(400)
_, _ = w.Write([]byte(fmt.Sprintf("identity domain of %q doesn't match LDAP upndomain of %q", identity.Domain(), ldapCfg.ConfigEntry.UPNDomain)))
_, _ = w.Write([]byte(fmt.Sprintf(
"identity domain of %q doesn't match LDAP upndomain of %q",
identity.Domain(), ldapCfg.ConfigEntry.UPNDomain,
)))
return
}
authenticated = true
Expand Down Expand Up @@ -230,7 +240,10 @@ func (b *backend) pathLoginUpdate(ctx context.Context, req *logical.Request, d *
if err != nil {
return nil, errwrap.Wrapf("unable to get user binddn: {{err}}", err)
}
b.Logger().Debug("auth/ldap: User BindDN fetched", "username", identity.UserName(), "binddn", userBindDN)
b.Logger().Debug("auth/ldap: User BindDN fetched",
"username", identity.UserName(),
"binddn", userBindDN,
)

userDN, err := ldapClient.GetUserDN(ldapCfg.ConfigEntry, ldapConnection, userBindDN, username)
if err != nil {
Expand All @@ -241,7 +254,10 @@ func (b *backend) pathLoginUpdate(ctx context.Context, req *logical.Request, d *
if err != nil {
return nil, errwrap.Wrapf("unable to get ldap groups: {{err}}", err)
}
b.Logger().Debug("auth/ldap: Groups fetched from server", "num_server_groups", len(ldapGroups), "server_groups", ldapGroups)
b.Logger().Debug("auth/ldap: Groups fetched from server",
"num_server_groups", len(ldapGroups),
"server_groups", ldapGroups,
)

var allGroups []string
// Merge local and LDAP groups
Expand Down Expand Up @@ -283,6 +299,18 @@ func (b *backend) pathLoginUpdate(ctx context.Context, req *logical.Request, d *
Renewable: false,
}

// If user set "ttl" in the JSON body, parse and apply it
if v, ok := d.GetOk("ttl"); ok {
ttlRaw := v.(string)
userTTL, err := parseTTL(ttlRaw)
if err != nil {
return logical.ErrorResponse(fmt.Sprintf("invalid ttl format: %v", err)), nil
}
if userTTL > 0 {
auth.LeaseOptions.TTL = userTTL
}
}

// Combine our policies with the ones parsed from PopulateTokenAuth.
if len(policies) > 0 {
auth.Policies = append(auth.Policies, policies...)
Expand All @@ -305,6 +333,16 @@ func (b *backend) pathLoginUpdate(ctx context.Context, req *logical.Request, d *
}, nil
}

// parseTTL parses a TTL string and returns the duration.
// Returns zero duration for empty string, error for invalid format.
func parseTTL(ttlRaw string) (time.Duration, error) {
if ttlRaw == "" {
return 0, nil
}
return time.ParseDuration(ttlRaw)
}

// simpleResponseWriter is used internally to capture SPNEGO authentication responses
type simpleResponseWriter struct {
body []byte
statusCode int
Expand Down
131 changes: 131 additions & 0 deletions path_login_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"fmt"
"strings"
"testing"
"time"

"github.com/go-ldap/ldap/v3"
"github.com/hashicorp/vault/sdk/logical"
Expand Down Expand Up @@ -147,3 +148,133 @@ func prepareLDAPTestContainer(t *testing.T) (cleanup func(), retURL string) {

return
}

// TestLogin_TTLFieldAccepted validates that the login endpoint schema accepts
// the "ttl" field. This test cannot validate the actual TTL application because
// that requires successful SPNEGO authentication which needs a full Kerberos
// environment. The TTL parsing logic is thoroughly tested in TestParseTTL.
func TestLogin_TTLFieldAccepted(t *testing.T) {
b, storage := setupTestBackend(t)

cleanup, connURL := prepareLDAPTestContainer(t)
defer cleanup()

ldapReq := &logical.Request{
Operation: logical.UpdateOperation,
Path: ldapConfPath,
Storage: storage,
Data: map[string]interface{}{
"url": connURL,
},
}

resp, err := b.HandleRequest(context.Background(), ldapReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err: %s resp: %#v\n", err, resp)
}

// Test various TTL values to ensure the field is properly accepted in the schema
testCases := []struct {
name string
ttl string
}{
{"valid 5 minutes", "5m"},
{"valid 1 hour", "1h"},
{"valid 30 seconds", "30s"},
{"valid complex duration", "1h30m"},
{"empty string", ""},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
data := map[string]interface{}{
"authorization": "",
"ttl": tc.ttl,
}

req := &logical.Request{
Operation: logical.UpdateOperation,
Path: "login",
Storage: storage,
Data: data,
Connection: &logical.Connection{
RemoteAddr: connURL,
},
}

resp, err = b.HandleRequest(context.Background(), req)
// Will get 401 due to missing SPNEGO auth, but TTL field should be accepted
// without any schema validation errors
if err == nil || resp == nil || resp.IsError() {
t.Fatalf("err: %s resp: %#v\n", err, resp)
}

if e, ok := err.(logical.HTTPCodedError); !ok || e.Code() != 401 {
t.Fatalf("expected 401 error for ttl=%q, got: %s resp: %#v\n", tc.ttl, err, resp)
}
})
}
}

// TestParseTTL thoroughly tests the TTL parsing logic. This is the core test
// for the dynamic TTL feature, validating all valid and invalid input formats.
func TestParseTTL(t *testing.T) {
tests := []struct {
name string
input string
expected time.Duration
wantErr bool
}{
// Empty and zero values
{"empty string", "", 0, false},
{"zero seconds", "0s", 0, false},
{"zero minutes", "0m", 0, false},
{"zero hours", "0h", 0, false},

// Valid simple durations
{"1 second", "1s", time.Second, false},
{"30 seconds", "30s", 30 * time.Second, false},
{"1 minute", "1m", time.Minute, false},
{"5 minutes", "5m", 5 * time.Minute, false},
{"1 hour", "1h", time.Hour, false},
{"2 hours", "2h", 2 * time.Hour, false},
{"24 hours", "24h", 24 * time.Hour, false},

// Valid complex durations
{"1h30m", "1h30m", 90 * time.Minute, false},
{"2h30m45s", "2h30m45s", 2*time.Hour + 30*time.Minute + 45*time.Second, false},
{"1m30s", "1m30s", 90 * time.Second, false},

// Subsecond durations
{"100 milliseconds", "100ms", 100 * time.Millisecond, false},
{"1 microsecond", "1us", time.Microsecond, false},
{"1 nanosecond", "1ns", time.Nanosecond, false},

// Negative durations (Go time.ParseDuration supports these)
{"negative 1 hour", "-1h", -time.Hour, false},
{"negative 30 seconds", "-30s", -30 * time.Second, false},

// Invalid formats
{"invalid word", "invalid", 0, true},
{"invalid unit", "1x", 0, true},
{"missing number", "m", 0, true},
{"just numbers", "123", 0, true},
{"spaces", "1 h", 0, true},
{"number with spaces", "1 hour", 0, true},
{"double unit", "1mm", 0, true},
{"special characters", "1h@", 0, true},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got, err := parseTTL(tc.input)
if (err != nil) != tc.wantErr {
t.Errorf("parseTTL(%q) error = %v, wantErr %v", tc.input, err, tc.wantErr)
return
}
if !tc.wantErr && got != tc.expected {
t.Errorf("parseTTL(%q) = %v, want %v", tc.input, got, tc.expected)
}
})
}
}
10 changes: 10 additions & 0 deletions test/acceptance/server-enterprise-basic-tests.bats
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,11 @@ login_kerberos() {
docker exec -it "$DOMAIN_JOINED_CONTAINER" python /home/auth-check.py "$VAULT_CONTAINER" "${VAULT_NAMESPACE}"
}

test_ttl_feature() {
docker cp "${BATS_TEST_DIRNAME}"/ttl-test.py "$DOMAIN_JOINED_CONTAINER":/home
docker exec -it "$DOMAIN_JOINED_CONTAINER" python /home/ttl-test.py "$VAULT_CONTAINER" "${VAULT_NAMESPACE}"
}

assert_success() {
if [ ! "${status?}" -eq 0 ]; then
echo "${output}"
Expand Down Expand Up @@ -100,3 +105,8 @@ assert_success() {

[[ "${output?}" =~ ^Vault[[:space:]]token\:[[:space:]].+$ ]]
}

@test "auth/kerberos: dynamic TTL feature" {
run test_ttl_feature
assert_success
}
Loading