Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions cmd/authd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"context"
"os"
"os/signal"
"strconv"
"sync"
"syscall"

Expand All @@ -30,6 +31,7 @@ type app interface {
}

func run(a app) int {
os.Setenv("AUTHD_PID", strconv.FormatInt(int64(os.Getpid()), 10))
defer installSignalHandler(a)()

if err := a.Run(); err != nil {
Expand Down
15 changes: 6 additions & 9 deletions nss/integration-tests/helper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,25 +6,22 @@ import (
"io"
"os"
"os/exec"
"path/filepath"
"slices"
"testing"
)

// getentOutputForLib returns the specific part for the nss command for the authd service.
// It uses the locally build authd nss module for the integration tests.
func getentOutputForLib(t *testing.T, libPath, socketPath string, rustCovEnv []string, shouldPreCheck bool, cmds ...string) (got string, exitCode int) {
func getentOutputForLib(t *testing.T, socketPath string, env []string, shouldPreCheck bool, cmds ...string) (got string, exitCode int) {
t.Helper()

// #nosec:G204 - we control the command arguments in tests
cmds = append(cmds, "--service", "authd")
cmd := exec.Command("getent", cmds...)
cmd.Env = append(cmd.Env,
"AUTHD_NSS_INFO=stderr",
// NSS needs both LD_PRELOAD and LD_LIBRARY_PATH to load the module library
fmt.Sprintf("LD_PRELOAD=%s:%s", libPath, os.Getenv("LD_PRELOAD")),
fmt.Sprintf("LD_LIBRARY_PATH=%s:%s", filepath.Dir(libPath), os.Getenv("LD_LIBRARY_PATH")),
)
cmd.Env = append(cmd.Env, rustCovEnv...)
cmd.Env = slices.Clone(env)

// Set the PID to to self, so that we can verify that it won't work for all.
cmd.Env = append(cmd.Env, fmt.Sprintf("AUTHD_PID=%d", os.Getpid()))

if socketPath != "" {
cmd.Env = append(cmd.Env, fmt.Sprintf("AUTHD_NSS_SOCKET=%s", socketPath))
Expand Down
133 changes: 129 additions & 4 deletions nss/integration-tests/integration_test.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
package nss_test

import (
"bytes"
"context"
"fmt"
"log"
"os"
"os/exec"
"os/user"
"path/filepath"
"strings"
"testing"
Expand All @@ -13,6 +17,7 @@ import (
"github.com/ubuntu/authd/internal/testutils"
"github.com/ubuntu/authd/internal/testutils/golden"
localgroupstestutils "github.com/ubuntu/authd/internal/users/localentries/testutils"
"gopkg.in/yaml.v3"
)

var daemonPath string
Expand All @@ -26,12 +31,21 @@ func TestIntegration(t *testing.T) {
libPath, rustCovEnv := testutils.BuildRustNSSLib(t, false, "should_pre_check_env")

// Create a default daemon to use for most test cases.
defaultSocket := filepath.Join(os.TempDir(), "nss-integration-tests.sock")
defaultSocket := filepath.Join(t.TempDir(), "nss.sock")
defaultDbState := "multiple_users_and_groups"
defaultOutputPath := filepath.Join(filepath.Dir(daemonPath), "gpasswd.output")
defaultGroupsFilePath := filepath.Join(testutils.TestFamilyPath(t), "gpasswd.group")

nssLibraryEnv := append(rustCovEnv,
"AUTHD_NSS_INFO=stderr",
// NSS needs both LD_PRELOAD and LD_LIBRARY_PATH to load the module library
fmt.Sprintf("LD_PRELOAD=%s:%s", libPath, os.Getenv("LD_PRELOAD")),
fmt.Sprintf("LD_LIBRARY_PATH=%s:%s", filepath.Dir(libPath), os.Getenv("LD_LIBRARY_PATH")),
)

env := append(localgroupstestutils.AuthdIntegrationTestsEnvWithGpasswdMock(t, defaultOutputPath, defaultGroupsFilePath), "AUTHD_INTEGRATIONTESTS_CURRENT_USER_AS_ROOT=1")
env = append(env, nssLibraryEnv...)
env = append(env, fmt.Sprintf("AUTHD_NSS_SOCKET=%s", defaultSocket))
ctx, cancel := context.WithCancel(context.Background())
_, stopped := testutils.RunDaemon(ctx, t, daemonPath,
testutils.WithSocketPath(defaultSocket),
Expand Down Expand Up @@ -118,12 +132,17 @@ func TestIntegration(t *testing.T) {
outPath := filepath.Join(t.TempDir(), "gpasswd.output")
groupsFilePath := filepath.Join("testdata", "empty.group")

socketPath = filepath.Join(t.TempDir(), "nss.sock")

var daemonStopped chan struct{}
ctx, cancel := context.WithCancel(context.Background())
env := localgroupstestutils.AuthdIntegrationTestsEnvWithGpasswdMock(t, outPath, groupsFilePath)
socketPath, daemonStopped = testutils.RunDaemon(ctx, t, daemonPath,
env = append(env, nssLibraryEnv...)
env = append(env, fmt.Sprintf("AUTHD_NSS_SOCKET=%s", socketPath))
_, daemonStopped = testutils.RunDaemon(ctx, t, daemonPath,
testutils.WithPreviousDBState(tc.dbState),
testutils.WithEnvironment(env...),
testutils.WithSocketPath(socketPath),
)
t.Cleanup(func() {
cancel()
Expand All @@ -136,7 +155,7 @@ func TestIntegration(t *testing.T) {
cmds = append(cmds, tc.key)
}

got, status := getentOutputForLib(t, libPath, socketPath, rustCovEnv, tc.shouldPreCheck, cmds...)
got, status := getentOutputForLib(t, socketPath, nssLibraryEnv, tc.shouldPreCheck, cmds...)
require.Equal(t, tc.wantStatus, status, "Expected status %d, but got %d", tc.wantStatus, status)

if tc.shouldPreCheck && tc.getentDB == "passwd" {
Expand Down Expand Up @@ -164,12 +183,118 @@ func TestIntegration(t *testing.T) {

// This is to check that some cache tasks, such as cleaning a corrupted database, work as expected.
if tc.wantSecondCall {
got, status := getentOutputForLib(t, libPath, socketPath, rustCovEnv, tc.shouldPreCheck, cmds...)
got, status := getentOutputForLib(t, socketPath, nssLibraryEnv, tc.shouldPreCheck, cmds...)
require.NotEqual(t, codeNotFound, status, "Expected no error, but got %v", status)
require.Empty(t, got, "Expected empty output, but got %q", got)
}
})
}

runPidAbuser := func(action, arg string) []byte {
require.NotEmpty(t, action, "Setup: action should not be empty")

// #nosec:G204 - we control the command arguments in tests
cmd := exec.Command("go", "run")
if testutils.CoverDirForTests() != "" {
// -cover is a "positional flag", so it needs to come right after the "build" command.
cmd.Args = append(cmd.Args, "-cover")
cmd.Env = testutils.AppendCovEnv(env)
}
if testutils.IsRace() {
cmd.Args = append(cmd.Args, "-race")
}
cmd.Env = append(cmd.Env, nssLibraryEnv...)
cmd.Env = append(cmd.Env,
fmt.Sprintf("AUTHD_NSS_SOCKET=%s", defaultSocket),
"ACTION="+action,
"ACTION_ARG="+arg,
)
cmd.Env = append(cmd.Env, os.Environ()...)

cmd.Dir = "pid_abuser"
cmd.Args = append(cmd.Args, "./")
var stdout bytes.Buffer
var stderr bytes.Buffer
cmd.Stdout = &stdout
cmd.Stderr = &stderr
err := cmd.Run()
require.NoError(t, err, "Could not run PID abuser: %s, %s",
stdout.String(), stderr.String())
t.Logf("STDOUT:\n%s", stdout.String())
t.Logf("STDERR:\n%s", stderr.String())
return stdout.Bytes()
}

t.Run("Simulate_running_as_authd", func(t *testing.T) {
tests := map[string]struct {
action string
arg string

want any
}{
"Lookups_user": {
action: "lookup_user",
arg: "user1",
want: user.User{
Uid: "1111",
Gid: "11111",
Username: "user1",
Name: "User1 gecos\nOn multiple lines",
HomeDir: "/home/user1",
},
},
"Lookups_group": {
action: "lookup_group",
arg: "group1",
want: user.Group{Gid: "11111", Name: "group1"},
},
"Lookups_uid": {
action: "lookup_uid",
arg: "1111",
want: user.User{
Uid: "1111",
Gid: "11111",
Username: "user1",
Name: "User1 gecos\nOn multiple lines",
HomeDir: "/home/user1",
},
},
"Lookups_gid": {
action: "lookup_gid",
arg: "11111",
want: user.Group{Gid: "11111", Name: "group1"},
},
}
for name, tc := range tests {
t.Run(name, func(t *testing.T) {
t.Parallel()

ret := runPidAbuser(tc.action, tc.arg)

switch action, _ := strings.CutPrefix(tc.action, "lookup_"); action {
case "user":
fallthrough
case "uid":
u := unmarshalYAML[user.User](t, ret)
require.Equal(t, tc.want, u, "User does not match")
case "group":
fallthrough
case "gid":
g := unmarshalYAML[user.Group](t, ret)
require.Equal(t, tc.want, g, "Group does not match")
}
})
}
})
}

func unmarshalYAML[T any](t *testing.T, yml []byte) T {
t.Helper()

var val T
err := yaml.Unmarshal(yml, &val)
require.NoError(t, err, "Unmarshalling failed:\n%q", yml)
return val
}

func TestMockgpasswd(t *testing.T) {
Expand Down
49 changes: 49 additions & 0 deletions nss/integration-tests/pid_abuser/pidabuser.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
// TiCS: disabled // This file is a test helper.

// Package main is the package for the pid abuser test tool.
package main

import (
"fmt"
"os"
"os/user"
"strconv"

"gopkg.in/yaml.v3"
)

func main() {
os.Setenv("AUTHD_PID", strconv.FormatInt(int64(os.Getpid()), 10))

action := os.Getenv("ACTION")
actionArg := os.Getenv("ACTION_ARG")

switch action {
case "lookup_user":
outputAsYAMLOrFail(user.Lookup(actionArg))

case "lookup_group":
outputAsYAMLOrFail(user.LookupGroup(actionArg))

case "lookup_uid":
outputAsYAMLOrFail(user.LookupId(actionArg))

case "lookup_gid":
outputAsYAMLOrFail(user.LookupGroupId(actionArg))

default:
panic("Invalid action " + action)
}
}

func outputAsYAMLOrFail[T any](val T, err error) {
if err != nil {
panic(err)
}

out, err := yaml.Marshal(val)
if err != nil {
panic(err)
}
fmt.Printf("%s\n", out)
}
70 changes: 66 additions & 4 deletions nss/src/client/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use authd::user_service_client::UserServiceClient;
use hyper_util::rt::TokioIo;
use std::error::Error;
use std::sync::OnceLock;
use tokio::net::UnixStream;
use tonic::transport::{Channel, Endpoint, Uri};
use tower::service_fn;
Expand All @@ -11,18 +12,79 @@ pub mod authd {
tonic::include_proto!("authd");
}

const AUTHD_PID_ENV_VAR: &str = "AUTHD_PID";

/// new_client creates a new client connection to the gRPC server or returns an active one.
pub async fn new_client() -> Result<UserServiceClient<Channel>, Box<dyn Error>> {
info!("Connecting to authd on {}...", super::socket_path());

// Cache for self-check result.
static AUTHD_PROCESS_CHECK: OnceLock<bool> = OnceLock::new();

let connector = service_fn(|_: Uri| async {
let stream = UnixStream::connect(super::socket_path()).await?;

if *AUTHD_PROCESS_CHECK.get_or_init(|| check_is_authd_process(&stream)) {
info!("Module loaded by authd itself: ignoring the connection");

return Err(std::io::Error::new(
std::io::ErrorKind::Unsupported,
"Ignoring connection from authd to authd itself",
));
}

Ok::<_, std::io::Error>(TokioIo::new(stream))
});

// The URL must have a valid format, even though we don't use it.
let ch = Endpoint::try_from("https://not-used:404")?
.connect_timeout(CONNECTION_TIMEOUT)
.connect_with_connector(service_fn(|_: Uri| async {
let stream = UnixStream::connect(super::socket_path()).await?;
Ok::<_, std::io::Error>(TokioIo::new(stream))
}))
.connect_with_connector(connector)
.await?;

Ok(UserServiceClient::new(ch))
}

fn check_is_authd_process(stream: &UnixStream) -> bool {
// Check if we've been launched with a AUTHD_PID env variable set with
// a numeric value. If these checks fail, we can just continue with the
// connection as we were. As for sure the library has not been loaded
// by authd.
let Ok(authd_pid) = std::env::var(AUTHD_PID_ENV_VAR) else {
return false;
};
info!(
"authd module launched with {}={}",
AUTHD_PID_ENV_VAR, authd_pid
);
let Ok(authd_pid_value) = authd_pid.parse::<u32>() else {
return false;
};

let current_pid = std::process::id();
info!("current PID is {}", current_pid);
if current_pid != authd_pid_value {
return false;
}

// Get the peer credentials, and check if the server PIDs matches the
// AUTHD_PID, an if it does, we can avoid any connection since we're
// sure that we have been loaded by authd (and not by another crafted
// client to act like it, to ignore the authd module)
let Ok(peer_cred) = stream.peer_cred() else {
return false;
};
let Some(peer_pid) = peer_cred.pid() else {
return false;
};

info!(
"authd socket is provided by PID {} (expecting {})",
peer_pid, authd_pid
);
if authd_pid_value != peer_pid.try_into().unwrap() {
return false;
}

return true;
}
Loading
Loading