diff --git a/path_login.go b/path_login.go index a78cb48..8ab6eee 100644 --- a/path_login.go +++ b/path_login.go @@ -10,6 +10,7 @@ import ( "fmt" "net/http" "strings" + "time" "github.com/hashicorp/errwrap" "github.com/hashicorp/go-hclog" @@ -35,6 +36,11 @@ func (b *backend) pathLogin() *framework.Path { Type: framework.TypeString, Description: `SPNEGO Authorization header. Required.`, }, + "ttl": { + Type: framework.TypeString, + Description: "Optional custom TTL for the session token (e.g., '1m')", + Default: "", + }, }, Operations: map[logical.Operation]framework.OperationHandler{ logical.ReadOperation: &framework.PathOperation{ @@ -102,6 +108,7 @@ func (b *backend) pathLoginUpdate(ctx context.Context, req *logical.Request, d * } } + // Obtain the SPNEGO token from headers or JSON field authorizationString := "" authorizationHeaders := req.Headers["Authorization"] if len(authorizationHeaders) > 0 { @@ -163,7 +170,10 @@ func (b *backend) pathLoginUpdate(ctx context.Context, req *logical.Request, d * // therefore policies, from a separate directory. if ldapCfg.ConfigEntry.UPNDomain != "" && identity.Domain() != ldapCfg.ConfigEntry.UPNDomain { w.WriteHeader(400) - _, _ = w.Write([]byte(fmt.Sprintf("identity domain of %q doesn't match LDAP upndomain of %q", identity.Domain(), ldapCfg.ConfigEntry.UPNDomain))) + _, _ = w.Write([]byte(fmt.Sprintf( + "identity domain of %q doesn't match LDAP upndomain of %q", + identity.Domain(), ldapCfg.ConfigEntry.UPNDomain, + ))) return } authenticated = true @@ -230,7 +240,10 @@ func (b *backend) pathLoginUpdate(ctx context.Context, req *logical.Request, d * if err != nil { return nil, errwrap.Wrapf("unable to get user binddn: {{err}}", err) } - b.Logger().Debug("auth/ldap: User BindDN fetched", "username", identity.UserName(), "binddn", userBindDN) + b.Logger().Debug("auth/ldap: User BindDN fetched", + "username", identity.UserName(), + "binddn", userBindDN, + ) userDN, err := ldapClient.GetUserDN(ldapCfg.ConfigEntry, ldapConnection, userBindDN, username) if err != nil { @@ -241,7 +254,10 @@ func (b *backend) pathLoginUpdate(ctx context.Context, req *logical.Request, d * if err != nil { return nil, errwrap.Wrapf("unable to get ldap groups: {{err}}", err) } - b.Logger().Debug("auth/ldap: Groups fetched from server", "num_server_groups", len(ldapGroups), "server_groups", ldapGroups) + b.Logger().Debug("auth/ldap: Groups fetched from server", + "num_server_groups", len(ldapGroups), + "server_groups", ldapGroups, + ) var allGroups []string // Merge local and LDAP groups @@ -283,6 +299,18 @@ func (b *backend) pathLoginUpdate(ctx context.Context, req *logical.Request, d * Renewable: false, } + // If user set "ttl" in the JSON body, parse and apply it + if v, ok := d.GetOk("ttl"); ok { + ttlRaw := v.(string) + userTTL, err := parseTTL(ttlRaw) + if err != nil { + return logical.ErrorResponse(fmt.Sprintf("invalid ttl format: %v", err)), nil + } + if userTTL > 0 { + auth.LeaseOptions.TTL = userTTL + } + } + // Combine our policies with the ones parsed from PopulateTokenAuth. if len(policies) > 0 { auth.Policies = append(auth.Policies, policies...) @@ -305,6 +333,16 @@ func (b *backend) pathLoginUpdate(ctx context.Context, req *logical.Request, d * }, nil } +// parseTTL parses a TTL string and returns the duration. +// Returns zero duration for empty string, error for invalid format. +func parseTTL(ttlRaw string) (time.Duration, error) { + if ttlRaw == "" { + return 0, nil + } + return time.ParseDuration(ttlRaw) +} + +// simpleResponseWriter is used internally to capture SPNEGO authentication responses type simpleResponseWriter struct { body []byte statusCode int diff --git a/path_login_test.go b/path_login_test.go index 4b71e4d..7cba5bc 100644 --- a/path_login_test.go +++ b/path_login_test.go @@ -8,6 +8,7 @@ import ( "fmt" "strings" "testing" + "time" "github.com/go-ldap/ldap/v3" "github.com/hashicorp/vault/sdk/logical" @@ -147,3 +148,133 @@ func prepareLDAPTestContainer(t *testing.T) (cleanup func(), retURL string) { return } + +// TestLogin_TTLFieldAccepted validates that the login endpoint schema accepts +// the "ttl" field. This test cannot validate the actual TTL application because +// that requires successful SPNEGO authentication which needs a full Kerberos +// environment. The TTL parsing logic is thoroughly tested in TestParseTTL. +func TestLogin_TTLFieldAccepted(t *testing.T) { + b, storage := setupTestBackend(t) + + cleanup, connURL := prepareLDAPTestContainer(t) + defer cleanup() + + ldapReq := &logical.Request{ + Operation: logical.UpdateOperation, + Path: ldapConfPath, + Storage: storage, + Data: map[string]interface{}{ + "url": connURL, + }, + } + + resp, err := b.HandleRequest(context.Background(), ldapReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err: %s resp: %#v\n", err, resp) + } + + // Test various TTL values to ensure the field is properly accepted in the schema + testCases := []struct { + name string + ttl string + }{ + {"valid 5 minutes", "5m"}, + {"valid 1 hour", "1h"}, + {"valid 30 seconds", "30s"}, + {"valid complex duration", "1h30m"}, + {"empty string", ""}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + data := map[string]interface{}{ + "authorization": "", + "ttl": tc.ttl, + } + + req := &logical.Request{ + Operation: logical.UpdateOperation, + Path: "login", + Storage: storage, + Data: data, + Connection: &logical.Connection{ + RemoteAddr: connURL, + }, + } + + resp, err = b.HandleRequest(context.Background(), req) + // Will get 401 due to missing SPNEGO auth, but TTL field should be accepted + // without any schema validation errors + if err == nil || resp == nil || resp.IsError() { + t.Fatalf("err: %s resp: %#v\n", err, resp) + } + + if e, ok := err.(logical.HTTPCodedError); !ok || e.Code() != 401 { + t.Fatalf("expected 401 error for ttl=%q, got: %s resp: %#v\n", tc.ttl, err, resp) + } + }) + } +} + +// TestParseTTL thoroughly tests the TTL parsing logic. This is the core test +// for the dynamic TTL feature, validating all valid and invalid input formats. +func TestParseTTL(t *testing.T) { + tests := []struct { + name string + input string + expected time.Duration + wantErr bool + }{ + // Empty and zero values + {"empty string", "", 0, false}, + {"zero seconds", "0s", 0, false}, + {"zero minutes", "0m", 0, false}, + {"zero hours", "0h", 0, false}, + + // Valid simple durations + {"1 second", "1s", time.Second, false}, + {"30 seconds", "30s", 30 * time.Second, false}, + {"1 minute", "1m", time.Minute, false}, + {"5 minutes", "5m", 5 * time.Minute, false}, + {"1 hour", "1h", time.Hour, false}, + {"2 hours", "2h", 2 * time.Hour, false}, + {"24 hours", "24h", 24 * time.Hour, false}, + + // Valid complex durations + {"1h30m", "1h30m", 90 * time.Minute, false}, + {"2h30m45s", "2h30m45s", 2*time.Hour + 30*time.Minute + 45*time.Second, false}, + {"1m30s", "1m30s", 90 * time.Second, false}, + + // Subsecond durations + {"100 milliseconds", "100ms", 100 * time.Millisecond, false}, + {"1 microsecond", "1us", time.Microsecond, false}, + {"1 nanosecond", "1ns", time.Nanosecond, false}, + + // Negative durations (Go time.ParseDuration supports these) + {"negative 1 hour", "-1h", -time.Hour, false}, + {"negative 30 seconds", "-30s", -30 * time.Second, false}, + + // Invalid formats + {"invalid word", "invalid", 0, true}, + {"invalid unit", "1x", 0, true}, + {"missing number", "m", 0, true}, + {"just numbers", "123", 0, true}, + {"spaces", "1 h", 0, true}, + {"number with spaces", "1 hour", 0, true}, + {"double unit", "1mm", 0, true}, + {"special characters", "1h@", 0, true}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got, err := parseTTL(tc.input) + if (err != nil) != tc.wantErr { + t.Errorf("parseTTL(%q) error = %v, wantErr %v", tc.input, err, tc.wantErr) + return + } + if !tc.wantErr && got != tc.expected { + t.Errorf("parseTTL(%q) = %v, want %v", tc.input, got, tc.expected) + } + }) + } +} diff --git a/test/acceptance/server-enterprise-basic-tests.bats b/test/acceptance/server-enterprise-basic-tests.bats index 11e1d22..6f7fa35 100644 --- a/test/acceptance/server-enterprise-basic-tests.bats +++ b/test/acceptance/server-enterprise-basic-tests.bats @@ -72,6 +72,11 @@ login_kerberos() { docker exec -it "$DOMAIN_JOINED_CONTAINER" python /home/auth-check.py "$VAULT_CONTAINER" "${VAULT_NAMESPACE}" } +test_ttl_feature() { + docker cp "${BATS_TEST_DIRNAME}"/ttl-test.py "$DOMAIN_JOINED_CONTAINER":/home + docker exec -it "$DOMAIN_JOINED_CONTAINER" python /home/ttl-test.py "$VAULT_CONTAINER" "${VAULT_NAMESPACE}" +} + assert_success() { if [ ! "${status?}" -eq 0 ]; then echo "${output}" @@ -100,3 +105,8 @@ assert_success() { [[ "${output?}" =~ ^Vault[[:space:]]token\:[[:space:]].+$ ]] } + +@test "auth/kerberos: dynamic TTL feature" { + run test_ttl_feature + assert_success +} diff --git a/test/acceptance/ttl-test.py b/test/acceptance/ttl-test.py new file mode 100644 index 0000000..8a36214 --- /dev/null +++ b/test/acceptance/ttl-test.py @@ -0,0 +1,151 @@ +# Copyright (c) HashiCorp, Inc. +# SPDX-License-Identifier: MPL-2.0 + +""" +Test script for validating the dynamic TTL feature in Kerberos auth. +Tests both with and without custom TTL values. +""" + +import json +import kerberos +import requests +import sys + +def get_kerberos_token(host): + """Get a Kerberos SPNEGO token for the given host.""" + service = "HTTP@{}".format(host) + rc, vc = kerberos.authGSSClientInit(service=service, mech_oid=kerberos.GSS_MECH_OID_SPNEGO) + kerberos.authGSSClientStep(vc, "") + return kerberos.authGSSClientResponse(vc) + +def login_with_ttl(host, namespace, ttl=None): + """ + Login to Vault with Kerberos auth, optionally specifying a TTL. + Returns the response JSON. + """ + kerberos_token = get_kerberos_token(host) + + url = "http://{}/v1/{}auth/kerberos/login".format(host, namespace) + headers = {'Authorization': 'Negotiate ' + kerberos_token} + + if ttl: + # Send TTL in the JSON body + response = requests.post(url, headers=headers, json={'ttl': ttl}) + else: + response = requests.post(url, headers=headers) + + if response.status_code != 200: + print("Login failed with status {}: {}".format(response.status_code, response.text)) + return None + + return response.json() + +def test_default_ttl(host, namespace): + """Test login without custom TTL.""" + print("Testing login without custom TTL...") + result = login_with_ttl(host, namespace) + if not result: + return False + + auth = result.get('auth', {}) + token = auth.get('client_token') + lease_duration = auth.get('lease_duration', 0) + + if not token: + print("FAIL: No client token received") + return False + + print("SUCCESS: Got token with default lease_duration={}s".format(lease_duration)) + return True + +def test_custom_ttl(host, namespace, ttl_str, expected_seconds): + """Test login with custom TTL.""" + print("Testing login with TTL='{}'...".format(ttl_str)) + result = login_with_ttl(host, namespace, ttl=ttl_str) + if not result: + return False + + auth = result.get('auth', {}) + token = auth.get('client_token') + lease_duration = auth.get('lease_duration', 0) + + if not token: + print("FAIL: No client token received") + return False + + if lease_duration != expected_seconds: + print("FAIL: Expected lease_duration={}s, got {}s".format(expected_seconds, lease_duration)) + return False + + print("SUCCESS: Got token with lease_duration={}s (expected {}s)".format(lease_duration, expected_seconds)) + return True + +def test_invalid_ttl(host, namespace): + """Test that invalid TTL returns an error.""" + print("Testing login with invalid TTL='invalid'...") + kerberos_token = get_kerberos_token(host) + + url = "http://{}/v1/{}auth/kerberos/login".format(host, namespace) + headers = {'Authorization': 'Negotiate ' + kerberos_token} + response = requests.post(url, headers=headers, json={'ttl': 'invalid'}) + + if response.status_code == 400: + error_msg = response.json().get('errors', [''])[0] + if 'invalid ttl format' in error_msg.lower(): + print("SUCCESS: Invalid TTL correctly rejected with error: {}".format(error_msg)) + return True + + print("FAIL: Expected 400 error for invalid TTL, got status {}".format(response.status_code)) + return False + +def main(): + if len(sys.argv) < 3: + print("Usage: {} ".format(sys.argv[0])) + sys.exit(1) + + prefix = sys.argv[1] + namespace = sys.argv[2] + host = prefix + ".matrix.lan:8200" + + print("=" * 60) + print("TTL Feature Tests") + print("Host: {}".format(host)) + print("Namespace: {}".format(namespace)) + print("=" * 60) + + results = [] + + # Test 1: Default TTL + results.append(("Default TTL", test_default_ttl(host, namespace))) + + # Test 2: Custom TTL of 5 minutes (300 seconds) + results.append(("Custom TTL 5m", test_custom_ttl(host, namespace, "5m", 300))) + + # Test 3: Custom TTL of 1 hour (3600 seconds) + results.append(("Custom TTL 1h", test_custom_ttl(host, namespace, "1h", 3600))) + + # Test 4: Custom TTL of 30 seconds + results.append(("Custom TTL 30s", test_custom_ttl(host, namespace, "30s", 30))) + + # Test 5: Invalid TTL + results.append(("Invalid TTL", test_invalid_ttl(host, namespace))) + + print("=" * 60) + print("Results Summary:") + all_passed = True + for name, passed in results: + status = "PASS" if passed else "FAIL" + print(" {}: {}".format(name, status)) + if not passed: + all_passed = False + print("=" * 60) + + if all_passed: + print("All TTL tests passed!") + sys.exit(0) + else: + print("Some TTL tests failed!") + sys.exit(1) + +if __name__ == "__main__": + main()